From 22b0a08a5f93c8902e8fc59aee55d02f66a0c9b3 Mon Sep 17 00:00:00 2001 From: Stream Date: Fri, 30 Jan 2026 08:15:42 +0800 Subject: [PATCH] fix: provides correct prompts, tools and terminal predicates Signed-off-by: Stream --- api/core/agent/base_agent_runner.py | 3 +- api/core/agent/output_tools.py | 29 +++++++- api/core/agent/patterns/function_call.py | 27 +++++++- api/core/agent/patterns/react.py | 68 +++++++++++++------ api/core/agent/prompt/template.py | 8 +-- .../agent_output/tools/illegal_output.py | 2 +- api/core/workflow/nodes/llm/node.py | 8 ++- 7 files changed, 112 insertions(+), 33 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e5da2b3e12..f26e8c68e8 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -7,7 +7,7 @@ from typing import Union, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext -from core.agent.output_tools import build_agent_output_tools +from core.agent.output_tools import build_agent_output_tools, select_output_tool_names from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -257,6 +257,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, invoke_from=self.application_generate_entity.invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, + output_tool_names=select_output_tool_names(structured_output_enabled=False), ) for tool in output_tools: tool_instances[tool.entity.identity.name] = tool diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py index 0e8d18cc04..5aa8a26ed4 100644 --- a/api/core/agent/output_tools.py +++ b/api/core/agent/output_tools.py @@ -25,15 +25,42 @@ OUTPUT_TOOL_NAMES: Sequence[str] = ( OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES) +def select_output_tool_names( + *, + structured_output_enabled: bool, + include_illegal_output: bool = False, +) -> list[str]: + tool_names = [OUTPUT_TEXT_TOOL] + if structured_output_enabled: + tool_names.append(FINAL_STRUCTURED_OUTPUT_TOOL) + else: + tool_names.append(FINAL_OUTPUT_TOOL) + if include_illegal_output: + tool_names.append(ILLEGAL_OUTPUT_TOOL) + return tool_names + + +def select_terminal_tool_name(*, structured_output_enabled: bool) -> str: + return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL + + def build_agent_output_tools( *, tenant_id: str, invoke_from: InvokeFrom, tool_invoke_from: ToolInvokeFrom, + output_tool_names: Sequence[str], structured_output_schema: dict[str, Any] | None = None, ) -> list[Tool]: tools: list[Tool] = [] - for tool_name in OUTPUT_TOOL_NAMES: + tool_names: list[str] = [] + for tool_name in output_tool_names: + if tool_name not in OUTPUT_TOOL_NAME_SET: + raise ValueError(f"Unknown output tool name: {tool_name}") + if tool_name not in tool_names: + tool_names.append(tool_name) + + for tool_name in tool_names: tool = ToolManager.get_tool_runtime( provider_type=ToolProviderType.BUILT_IN, provider_id=OUTPUT_TOOL_PROVIDER, diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py index f29e86b3f4..da7b6422b3 100644 --- a/api/core/agent/patterns/function_call.py +++ b/api/core/agent/patterns/function_call.py @@ -42,6 +42,14 @@ class FunctionCallStrategy(AgentPattern): """Execute the function call agent strategy.""" # Convert tools to prompt format prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET} + if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names: + terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL + elif FINAL_OUTPUT_TOOL in available_output_tool_names: + terminal_tool_name = FINAL_OUTPUT_TOOL + else: + raise ValueError("No terminal output tool configured") + allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names # Initialize tracking iteration_step: int = 1 @@ -54,6 +62,7 @@ class FunctionCallStrategy(AgentPattern): output_text_payload: str | None = None finish_reason: str | None = None output_files: list[File] = [] # Track files produced by tools + terminal_output_seen = False class _LLMInvoker(Protocol): def invoke_llm( @@ -79,7 +88,7 @@ class FunctionCallStrategy(AgentPattern): yield round_log # On last iteration, restrict tools to output tools if iteration_step == max_iterations: - current_tools = [tool for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET] + current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names] else: current_tools = prompt_tools model_log = self._create_log( @@ -115,6 +124,8 @@ class FunctionCallStrategy(AgentPattern): ) if not tool_calls: + if not allow_illegal_output: + raise ValueError("Model did not call any tools") tool_calls = [ ( str(uuid.uuid4()), @@ -149,9 +160,12 @@ class FunctionCallStrategy(AgentPattern): elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL: data = tool_args.get("data") structured_output_payload = cast(dict[str, Any] | None, data) + if tool_name == terminal_tool_name: + terminal_tool_seen = True elif tool_name == FINAL_OUTPUT_TOOL: final_text = self._format_output_text(tool_args.get("text")) - terminal_tool_seen = True + if tool_name == terminal_tool_name: + terminal_tool_seen = True tool_response, tool_files, _ = yield from self._handle_tool_call( tool_name, tool_args, tool_call_id, messages, round_log @@ -161,6 +175,7 @@ class FunctionCallStrategy(AgentPattern): output_files.extend(tool_files) if terminal_tool_seen: + terminal_output_seen = True function_call_state = False yield self._finish_log( round_log, @@ -181,7 +196,13 @@ class FunctionCallStrategy(AgentPattern): from core.agent.entities import AgentResult output_payload: str | AgentResult.StructuredOutput - if final_text: + if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT, + output_text=None, + output_data=structured_output_payload, + ) + elif final_text: output_payload = AgentResult.StructuredOutput( output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, output_text=final_text, diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py index 207a824742..14986093cf 100644 --- a/api/core/agent/patterns/react.py +++ b/api/core/agent/patterns/react.py @@ -77,6 +77,17 @@ class ReActStrategy(AgentPattern): structured_output_payload: dict[str, Any] | None = None output_text_payload: str | None = None finish_reason: str | None = None + terminal_output_seen = False + available_output_tool_names = { + tool.entity.identity.name for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET + } + if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names: + terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL + elif FINAL_OUTPUT_TOOL in available_output_tool_names: + terminal_tool_name = FINAL_OUTPUT_TOOL + else: + raise ValueError("No terminal output tool configured") + allow_illegal_output = ILLEGAL_OUTPUT_TOOL in available_output_tool_names # Add "Observation" to stop sequences if "Observation" not in stop: @@ -95,7 +106,9 @@ class ReActStrategy(AgentPattern): # Build prompt with tool restrictions on last iteration if iteration_step == max_iterations: - tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name in OUTPUT_TOOL_NAME_SET] + tools_for_prompt = [ + tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names + ] else: tools_for_prompt = self.tools current_messages = self._build_prompt_with_react_format( @@ -150,6 +163,8 @@ class ReActStrategy(AgentPattern): # Check if we have an action to execute if scratchpad.action is None: + if not allow_illegal_output: + raise ValueError("Model did not call any tools") illegal_action = AgentScratchpadUnit.Action( action_name=ILLEGAL_OUTPUT_TOOL, action_input={"raw": scratchpad.thought or ""}, @@ -162,33 +177,29 @@ class ReActStrategy(AgentPattern): output_files.extend(tool_files) else: action_name = scratchpad.action.action_name - if action_name == FINAL_OUTPUT_TOOL: + if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict): + output_text_payload = scratchpad.action.action_input.get("text") + elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict): + data = scratchpad.action.action_input.get("data") + if isinstance(data, dict): + structured_output_payload = data + elif action_name == FINAL_OUTPUT_TOOL: if isinstance(scratchpad.action.action_input, dict): final_text = self._format_output_text(scratchpad.action.action_input.get("text")) else: final_text = self._format_output_text(scratchpad.action.action_input) - observation, tool_files = yield from self._handle_tool_call( - scratchpad.action, current_messages, round_log - ) - scratchpad.observation = observation - output_files.extend(tool_files) + + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + output_files.extend(tool_files) + + if action_name == terminal_tool_name: + terminal_output_seen = True react_state = False else: - if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict): - output_text_payload = scratchpad.action.action_input.get("text") - elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance( - scratchpad.action.action_input, dict - ): - data = scratchpad.action.action_input.get("data") - if isinstance(data, dict): - structured_output_payload = data - react_state = True - observation, tool_files = yield from self._handle_tool_call( - scratchpad.action, current_messages, round_log - ) - scratchpad.observation = observation - output_files.extend(tool_files) yield self._finish_log( round_log, @@ -207,7 +218,13 @@ class ReActStrategy(AgentPattern): from core.agent.entities import AgentResult output_payload: str | AgentResult.StructuredOutput - if final_text: + if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen: + output_payload = AgentResult.StructuredOutput( + output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT, + output_text=None, + output_data=structured_output_payload, + ) + elif final_text: output_payload = AgentResult.StructuredOutput( output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, output_text=final_text, @@ -268,12 +285,19 @@ class ReActStrategy(AgentPattern): tools_str = "No tools available" tool_names_str = "" + final_tool_name = FINAL_OUTPUT_TOOL + if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names: + final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL + if final_tool_name not in tool_names: + raise ValueError("No terminal output tool available for prompt") + # Replace placeholders in the existing system prompt updated_content = msg.content assert isinstance(updated_content, str) updated_content = updated_content.replace("{{instruction}}", instruction) updated_content = updated_content.replace("{{tools}}", tools_str) updated_content = updated_content.replace("{{tool_names}}", tool_names_str) + updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name) # Create new SystemPromptMessage with updated content messages[i] = SystemPromptMessage(content=updated_content) diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index e46d504e30..0bbc9007b7 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -7,7 +7,7 @@ You have access to the following tools: {{tools}} Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish. +Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish. Provide only ONE action per $JSON_BLOB, as shown: @@ -32,7 +32,7 @@ Thought: I know what to respond Action: ``` { - "action": "final_output_answer", + "action": "{{final_tool_name}}", "action_input": { "text": "Final response to human" } @@ -58,7 +58,7 @@ You have access to the following tools: {{tools}} Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -Valid "action" values: {{tool_names}}. You must call "final_output_answer" to finish. +Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish. Provide only ONE action per $JSON_BLOB, as shown: @@ -83,7 +83,7 @@ Thought: I know what to respond Action: ``` { - "action": "final_output_answer", + "action": "{{final_tool_name}}", "action_input": { "text": "Final response to human" } diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py index 27301d61b3..2276c527e9 100644 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py +++ b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py @@ -16,6 +16,6 @@ class IllegalOutputTool(BuiltinTool): ) -> Generator[ToolInvokeMessage, None, None]: message = ( "Protocol violation: do not output plain text. " - "Call output_text, final_structured_output, then final_output_answer." + "Call an output tool and finish with the configured terminal tool." ) yield self.create_text_message(message) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 8081055dad..628aaaa334 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext -from core.agent.output_tools import build_agent_output_tools +from core.agent.output_tools import build_agent_output_tools, select_output_tool_names from core.agent.patterns import StrategyFactory from core.app.entities.app_asset_entities import AppAssetFileTree from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -1935,6 +1935,9 @@ class LLMNode(Node[LLMNodeData]): tenant_id=self.tenant_id, invoke_from=self.invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + output_tool_names=select_output_tool_names( + structured_output_enabled=self._node_data.structured_output_enabled + ), structured_output_schema=structured_output_schema, ) @@ -2037,6 +2040,9 @@ class LLMNode(Node[LLMNodeData]): tenant_id=self.tenant_id, invoke_from=self.invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + output_tool_names=select_output_tool_names( + structured_output_enabled=self._node_data.structured_output_enabled + ), structured_output_schema=structured_output_schema, ) )