From 884eeebe83b18db19c6d4c531f0bf5cd87cf343f Mon Sep 17 00:00:00 2001
From: jyong
Date: Wed, 20 Mar 2024 04:00:50 +0800
Subject: [PATCH] fix react response

---
 .../structed_multi_dataset_router_agent.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py b/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
index 2882707783..33e30f10b4 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
@@ -14,6 +14,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
+from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from core.workflow.nodes.llm.llm_node import LLMNode
 from pydantic import Field
@@ -92,7 +93,7 @@ class ReactMultiDatasetRouter:
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-    ) -> str:
+    ) -> Union[str, None]:
         if model_config.mode == "chat":
             prompt = self.create_chat_prompt(
                 query=query,
@@ -109,7 +110,7 @@ class ReactMultiDatasetRouter:
                 format_instructions=format_instructions,
                 input_variables=None
             )
-        stop = model_config.stop
+        stop = ['Observation:']
         # handle invoke result
         prompt_transform = AdvancedPromptTransform()
         prompt_messages = prompt_transform.get_prompt(
@@ -130,13 +131,11 @@ class ReactMultiDatasetRouter:
             user_id=user_id,
             tenant_id=tenant_id
         )
-        output_parser: AgentOutputParser = Field(
-            default_factory=StructuredChatOutputParserWithRetries
-        )
+        output_parser = StructuredChatOutputParser()
         agent_decision = output_parser.parse(result_text)
         if isinstance(agent_decision, AgentAction):
-            tool_inputs = agent_decision.tool_input
-            return tool_inputs
+            return agent_decision.tool
+        return None
 
     def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
                     model_instance: ModelInstance,
@@ -207,7 +206,7 @@ class ReactMultiDatasetRouter:
     ) -> list[ChatModelMessage]:
         tool_strings = []
         for tool in tools:
-            tool_strings.append(f"dataset_{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
+            tool_strings.append(f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
         unique_tool_names = set(tool.name for tool in tools)
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
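
A minimal sketch of the behaviour the patch relies on, assuming the ReAct response follows the structured-chat format (an "Action:" line followed by a fenced JSON object with "action" and "action_input" keys): generation is cut at the 'Observation:' stop token, the text is parsed into an action, and the router now returns the chosen tool (dataset) name rather than the tool input, or None when no dataset is selected. The helper parse_dataset_choice below is hypothetical and only stands in for Dify's StructuredChatOutputParser; its Optional[str] return mirrors the patched Union[str, None] signature.

import json
import re
from typing import Optional


# Hypothetical stand-in for StructuredChatOutputParser: pulls the "action"
# field out of a structured-chat ReAct response. Illustration only, not
# Dify's actual parser implementation.
def parse_dataset_choice(llm_text: str) -> Optional[str]:
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", llm_text, re.DOTALL)
    if not match:
        return None  # no parsable action block -> router selects no dataset
    try:
        blob = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    action = blob.get("action")
    # Mirrors the patched behaviour: return the tool (dataset) name itself,
    # not the tool input, and None when the model gives a final answer.
    return action if action and action != "Final Answer" else None


if __name__ == "__main__":
    sample = (
        "Thought: the question concerns refunds, use the billing dataset.\n"
        "Action:\n"
        "```json\n"
        '{"action": "billing_dataset", "action_input": {"query": "refund policy"}}\n'
        "```\n"
    )
    print(parse_dataset_choice(sample))  # -> billing_dataset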