From 9042db301d376fa5a5ae386111a88f56a15aeae7 Mon Sep 17 00:00:00 2001
From: jyong
Date: Wed, 20 Mar 2024 03:50:28 +0800
Subject: [PATCH] fix page content is empty

---
 .../knowledge_retrieval_node.py               | 30 +++-------
 .../multi_dataset_function_call_router.py     | 58 +++++++++++++++++++
 .../structed_multi_dataset_router_agent.py    | 14 ++++-
 3 files changed, 79 insertions(+), 23 deletions(-)
 create mode 100644 api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py

diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 6e38849a26..5dd5195449 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -19,6 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.workflow.nodes.knowledge_retrieval.structed_multi_dataset_router_agent import ReactMultiDatasetRouter
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
@@ -214,32 +215,19 @@ class KnowledgeRetrievalNode(BaseNode):
                 if ModelFeature.TOOL_CALL in features \
                         or ModelFeature.MULTI_TOOL_CALL in features:
                     planning_strategy = PlanningStrategy.ROUTER
-
+            dataset_id = None
             if planning_strategy == PlanningStrategy.REACT_ROUTER:
                 react_multi_dataset_router = ReactMultiDatasetRouter()
-                return react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
-                                                         self.user_id, self.tenant_id)
+                dataset_id = react_multi_dataset_router.invoke(query, tools, node_data, model_config, model_instance,
+                                                               self.user_id, self.tenant_id)
 
-            prompt_messages = [
-                SystemPromptMessage(content='You are a helpful AI assistant.'),
-                UserPromptMessage(content=query)
-            ]
-            result = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=tools,
-                stream=False,
-                model_parameters={
-                    'temperature': 0.2,
-                    'top_p': 0.3,
-                    'max_tokens': 1500
-                }
-            )
-
-            if result.message.tool_calls:
+            elif planning_strategy == PlanningStrategy.ROUTER:
+                function_call_router = FunctionCallMultiDatasetRouter()
+                dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
+            if dataset_id:
                 # get retrieval model config
-                function_call_name = result.message.tool_calls[0].function.name
                 dataset = db.session.query(Dataset).filter(
-                    Dataset.id == function_call_name
+                    Dataset.id == dataset_id
                 ).first()
                 if dataset:
                     retrieval_model_config = dataset.retrieval_model \
diff --git a/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py b/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py
new file mode 100644
index 0000000000..9d723c5cee
--- /dev/null
+++ b/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_function_call_router.py
@@ -0,0 +1,58 @@
+from collections.abc import Generator, Sequence
+from typing import Optional, Union
+
+from langchain import PromptTemplate
+from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
+from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool, \
+    SystemPromptMessage, UserPromptMessage
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
+from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from core.workflow.nodes.llm.llm_node import LLMNode
+
+
+class FunctionCallMultiDatasetRouter:
+
+    def invoke(
+            self,
+            query: str,
+            dataset_tools: list[PromptMessageTool],
+            model_config: ModelConfigWithCredentialsEntity,
+            model_instance: ModelInstance,
+
+    ) -> Union[str, None]:
+        """Given input, decided what to do.
+        Returns:
+            Action specifying what tool to use.
+        """
+        if len(dataset_tools) == 0:
+            return None
+        elif len(dataset_tools) == 1:
+            return dataset_tools[0].name
+
+        try:
+            prompt_messages = [
+                SystemPromptMessage(content='You are a helpful AI assistant.'),
+                UserPromptMessage(content=query)
+            ]
+            result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                tools=dataset_tools,
+                stream=False,
+                model_parameters={
+                    'temperature': 0.2,
+                    'top_p': 0.3,
+                    'max_tokens': 1500
+                }
+            )
+            if result.message.tool_calls:
+                # get retrieval model config
+                return result.message.tool_calls[0].function.name
+            return None
+        except Exception as e:
+            return None
\ No newline at end of file
diff --git a/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py b/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
index f694a01346..2882707783 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py
@@ -2,8 +2,11 @@ from collections.abc import Generator, Sequence
 from typing import Optional, Union
 
 from langchain import PromptTemplate
+from langchain.agents import AgentOutputParser
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParserWithRetries
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
+from langchain.schema import AgentAction
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
@@ -13,6 +16,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from core.workflow.nodes.llm.llm_node import LLMNode
+from pydantic import Field
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -126,7 +130,13 @@ class ReactMultiDatasetRouter:
                 user_id=user_id,
                 tenant_id=tenant_id
             )
-        return result_text
+        output_parser: AgentOutputParser = Field(
+            default_factory=StructuredChatOutputParserWithRetries
+        )
+        agent_decision = output_parser.parse(result_text)
+        if isinstance(agent_decision, AgentAction):
+            tool_inputs = agent_decision.tool_input
+            return tool_inputs
 
     def _invoke_llm(self, node_data: KnowledgeRetrievalNodeData,
                     model_instance: ModelInstance,
@@ -197,7 +207,7 @@ class ReactMultiDatasetRouter:
     ) -> list[ChatModelMessage]:
         tool_strings = []
        for tool in tools:
-            tool_strings.append(f"{tool.name}: {tool.description}")
+            tool_strings.append(f"dataset_{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}")
         formatted_tools = "\n".join(tool_strings)
         unique_tool_names = set(tool.name for tool in tools)
         tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
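
Note (reviewer sketch, not part of the patch): the change above reduces dataset routing to a small decision — no candidate dataset means no retrieval, a single candidate is returned directly, and only the multi-dataset case asks the model to choose a tool via function calling, with the returned tool name treated as the dataset id. A minimal standalone illustration of that flow follows; DatasetTool, pick_dataset, and fake_llm are stand-in names for this sketch, not Dify APIs (the patched entry point is FunctionCallMultiDatasetRouter.invoke above).

# Illustrative sketch only: DatasetTool, pick_dataset and fake_llm are stand-ins,
# not Dify APIs; the patched implementation is FunctionCallMultiDatasetRouter.invoke.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class DatasetTool:
    name: str          # dataset id, doubling as the tool/function name
    description: str   # dataset description shown to the model


def pick_dataset(query: str,
                 tools: list[DatasetTool],
                 call_llm: Callable[[str, list[DatasetTool]], Optional[str]]) -> Optional[str]:
    """Return the chosen dataset id, or None when nothing should be retrieved."""
    if not tools:            # no datasets configured: skip retrieval entirely
        return None
    if len(tools) == 1:      # a single dataset needs no model call
        return tools[0].name
    # several datasets: let the model pick one via a tool/function call and
    # read the dataset id back from the returned tool name
    return call_llm(query, tools)


def fake_llm(query: str, tools: list[DatasetTool]) -> Optional[str]:
    # stand-in for model_instance.invoke_llm(..., tools=..., stream=False):
    # always "calls" the first tool so the example stays deterministic
    return tools[0].name


if __name__ == "__main__":
    tools = [DatasetTool("dataset-a", "product FAQ"),
             DatasetTool("dataset-b", "API reference")]
    print(pick_dataset("How do I reset my password?", tools, fake_llm))  # -> dataset-a

The short-circuit for zero or one dataset is what lets the node skip an LLM round-trip when routing is unambiguous; only genuinely ambiguous cases pay for a model call.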