From 9ad49340bff51e7aa42f41819db4ad8c86b8e953 Mon Sep 17 00:00:00 2001 From: Stream Date: Sat, 31 Jan 2026 00:39:57 +0800 Subject: [PATCH] refactor: remove union types Signed-off-by: Stream --- api/core/agent/agent_app_runner.py | 70 +++------ api/core/agent/base_agent_runner.py | 14 +- api/core/agent/entities.py | 15 +- api/core/agent/output_tools.py | 107 ++++++++----- api/core/agent/patterns/base.py | 3 +- api/core/agent/patterns/function_call.py | 124 +++++---------- api/core/agent/patterns/react.py | 94 ++++-------- api/core/agent/patterns/strategy_factory.py | 24 ++- api/core/workflow/nodes/llm/node.py | 145 +++++++----------- .../output_parser/test_cot_output_parser.py | 5 +- .../core/agent/patterns/test_base.py | 2 +- .../core/agent/test_agent_app_runner.py | 13 +- .../unit_tests/core/agent/test_entities.py | 2 +- 13 files changed, 257 insertions(+), 361 deletions(-) diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py index 3a0d96a1f9..0730827aca 100644 --- a/api/core/agent/agent_app_runner.py +++ b/api/core/agent/agent_app_runner.py @@ -1,10 +1,11 @@ +import json import logging from collections.abc import Generator from copy import deepcopy from typing import Any from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult +from core.agent.entities import AgentEntity, AgentLog, AgentResult from core.agent.patterns.strategy_factory import StrategyFactory from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent @@ -13,7 +14,6 @@ from core.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, - LLMResultChunkDelta, LLMUsage, PromptMessage, PromptMessageContentType, @@ -118,7 +118,6 @@ class AgentAppRunner(BaseAgentRunner): prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, stop=app_generate_entity.model_conf.stop, - stream=False, ) # Consume generator and collect result @@ -256,49 +255,30 @@ class AgentAppRunner(BaseAgentRunner): raise # Process final result - if isinstance(result, AgentResult): - output_payload = result.output - if isinstance(output_payload, AgentResult.StructuredOutput): - if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT: - raise ValueError("Agent returned illegal output") - if output_payload.output_kind not in { - AgentOutputKind.FINAL_OUTPUT_ANSWER, - AgentOutputKind.OUTPUT_TEXT, - }: - raise ValueError("Agent did not return text output") - if not output_payload.output_text: - raise ValueError("Agent returned empty text output") - final_answer = output_payload.output_text - else: - if not output_payload: - raise ValueError("Agent returned empty output") - final_answer = str(output_payload) - usage = result.usage or LLMUsage.empty_usage() + if not isinstance(result, AgentResult): + raise ValueError("Agent did not return AgentResult") + output_payload = result.output + if isinstance(output_payload, dict): + final_answer = json.dumps(output_payload, ensure_ascii=False) + elif isinstance(output_payload, str): + final_answer = output_payload + else: + raise ValueError("Final output is not a string or structured data.") + usage = result.usage or LLMUsage.empty_usage() - # Publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=self.model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=usage, - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - if False: - yield LLMResultChunk( - model="", - prompt_messages=[], - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=""), - usage=None, - ), - ) + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=self.model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index d1455a6e05..6a474f5bed 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -7,7 +7,7 @@ from typing import Union, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext -from core.agent.output_tools import build_agent_output_tools, select_output_tool_names +from core.agent.output_tools import build_agent_output_tools from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -253,15 +253,9 @@ class BaseAgentRunner(AppRunner): # save tool entity tool_instances[dataset_tool.entity.identity.name] = dataset_tool - output_tools = build_agent_output_tools( - tenant_id=self.tenant_id, - invoke_from=self.application_generate_entity.invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT, - output_tool_names=select_output_tool_names( - structured_output_enabled=False, - include_illegal_output=True, - ), - ) + output_tools = build_agent_output_tools(tenant_id=self.tenant_id, + invoke_from=self.application_generate_entity.invoke_from, + tool_invoke_from=ToolInvokeFrom.AGENT) for tool in output_tools: tool_instances[tool.entity.identity.name] = tool diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index b606bb9048..785299221f 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -190,24 +190,11 @@ class AgentOutputKind(StrEnum): ILLEGAL_OUTPUT = "illegal_output" -OutputKind = AgentOutputKind - - class AgentResult(BaseModel): """ Agent execution result. """ - - class StructuredOutput(BaseModel): - """ - Structured output payload from output tools. - """ - - output_kind: AgentOutputKind - output_text: str | None = None - output_data: Mapping[str, Any] | None = None - - output: str | StructuredOutput = Field(default="", description="The generated output") + output: str | dict = Field(default="", description="The generated output") files: list[Any] = Field(default_factory=list, description="Files produced during execution") usage: Any | None = Field(default=None, description="LLM usage statistics") finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py index 5aa8a26ed4..ad2e0b225a 100644 --- a/api/core/agent/output_tools.py +++ b/api/core/agent/output_tools.py @@ -1,11 +1,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager OUTPUT_TOOL_PROVIDER = "agent_output" @@ -22,63 +23,87 @@ OUTPUT_TOOL_NAMES: Sequence[str] = ( ILLEGAL_OUTPUT_TOOL, ) -OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES) +TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL) -def select_output_tool_names( - *, - structured_output_enabled: bool, - include_illegal_output: bool = False, -) -> list[str]: - tool_names = [OUTPUT_TEXT_TOOL] - if structured_output_enabled: - tool_names.append(FINAL_STRUCTURED_OUTPUT_TOOL) - else: - tool_names.append(FINAL_OUTPUT_TOOL) - if include_illegal_output: - tool_names.append(ILLEGAL_OUTPUT_TOOL) - return tool_names +TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session." -def select_terminal_tool_name(*, structured_output_enabled: bool) -> str: +def is_terminal_output_tool(tool_name: str) -> bool: + return tool_name in TERMINAL_OUTPUT_TOOL_NAMES + + +def get_terminal_tool_name(structured_output_enabled: bool) -> str: return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL +OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES) + + def build_agent_output_tools( - *, tenant_id: str, invoke_from: InvokeFrom, tool_invoke_from: ToolInvokeFrom, - output_tool_names: Sequence[str], - structured_output_schema: dict[str, Any] | None = None, + structured_output_schema: Mapping[str, Any] | None = None, ) -> list[Tool]: - tools: list[Tool] = [] - tool_names: list[str] = [] - for tool_name in output_tool_names: - if tool_name not in OUTPUT_TOOL_NAME_SET: - raise ValueError(f"Unknown output tool name: {tool_name}") - if tool_name not in tool_names: - tool_names.append(tool_name) - - for tool_name in tool_names: - tool = ToolManager.get_tool_runtime( + def get_tool_runtime(_tool_name: str) -> Tool: + return ToolManager.get_tool_runtime( provider_type=ToolProviderType.BUILT_IN, provider_id=OUTPUT_TOOL_PROVIDER, - tool_name=tool_name, + tool_name=_tool_name, tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) - if tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and structured_output_schema: - tool.entity = tool.entity.model_copy(deep=True) - for parameter in tool.entity.parameters: - if parameter.name != "data": - continue - parameter.type = ToolParameter.ToolParameterType.OBJECT - parameter.form = ToolParameter.ToolParameterForm.LLM - parameter.required = True - parameter.input_schema = structured_output_schema - tools.append(tool) + tools: list[Tool] = [ + get_tool_runtime(OUTPUT_TEXT_TOOL), + get_tool_runtime(ILLEGAL_OUTPUT_TOOL), + ] + + if structured_output_schema: + raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL) + raw_tool.entity = raw_tool.entity.model_copy(deep=True) + data_parameter = ToolParameter( + name="data", + type=ToolParameter.ToolParameterType.OBJECT, + form=ToolParameter.ToolParameterForm.LLM, + required=True, + input_schema=dict(structured_output_schema), + label=I18nObject(en_US="__Data", zh_Hans="__Data"), + ) + raw_tool.entity.parameters = [data_parameter] + + def invoke_tool( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> ToolInvokeMessage: + data = tool_parameters["data"] + if not data: + return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required")) + if not isinstance(data, dict): + return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict")) + return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE)) + + raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage] + tools.append(raw_tool) + else: + raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL) + + def invoke_tool( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> ToolInvokeMessage: + return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE)) + raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage] + tools.append(raw_tool) return tools diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py index fbb8d304a5..bb80472f06 100644 --- a/api/core/agent/patterns/base.py +++ b/api/core/agent/patterns/base.py @@ -60,8 +60,7 @@ class AgentPattern(ABC): self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], - stop: list[str] = [], - stream: bool = True, + stop: list[str] ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: """Execute the agent strategy.""" pass diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py index 8da0308ba4..53c2159807 100644 --- a/api/core/agent/patterns/function_call.py +++ b/api/core/agent/patterns/function_call.py @@ -5,13 +5,10 @@ import uuid from collections.abc import Generator from typing import Any, Literal, Protocol, Union, cast -from core.agent.entities import AgentLog, AgentOutputKind, AgentResult +from core.agent.entities import AgentLog, AgentResult from core.agent.output_tools import ( - FINAL_OUTPUT_TOOL, - FINAL_STRUCTURED_OUTPUT_TOOL, ILLEGAL_OUTPUT_TOOL, - OUTPUT_TEXT_TOOL, - OUTPUT_TOOL_NAME_SET, + TERMINAL_OUTPUT_MESSAGE, ) from core.file import File from core.model_runtime.entities import ( @@ -36,21 +33,11 @@ class FunctionCallStrategy(AgentPattern): self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], - stop: list[str] = [], - stream: bool = True, + stop: list[str] ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: """Execute the function call agent strategy.""" # Convert tools to prompt format prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() - tool_instance_names = {tool.entity.identity.name for tool in self.tools} - available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET} - if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names: - terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL - elif FINAL_OUTPUT_TOOL in available_output_tool_names: - terminal_tool_name = FINAL_OUTPUT_TOOL - else: - raise ValueError("No terminal output tool configured") - allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names # Initialize tracking iteration_step: int = 1 @@ -60,10 +47,10 @@ class FunctionCallStrategy(AgentPattern): messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy final_text: str = "" structured_output_payload: dict[str, Any] | None = None + final_tool_args: dict[str, Any] = {"!!!": "!!!"} output_text_payload: str | None = None finish_reason: str | None = None output_files: list[File] = [] # Track files produced by tools - terminal_output_seen = False class _LLMInvoker(Protocol): def invoke_llm( @@ -87,11 +74,7 @@ class FunctionCallStrategy(AgentPattern): data={}, ) yield round_log - # On last iteration, restrict tools to output tools - if iteration_step == max_iterations: - current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names] - else: - current_tools = prompt_tools + model_log = self._create_log( label=f"{self.model_instance.model} Thought", log_type=AgentLog.LogType.THOUGHT, @@ -112,7 +95,7 @@ class FunctionCallStrategy(AgentPattern): chunks = invoker.invoke_llm( prompt_messages=messages, model_parameters=model_parameters, - tools=current_tools, + tools=prompt_tools, stop=stop, stream=False, user=self.context.user_id, @@ -124,19 +107,15 @@ class FunctionCallStrategy(AgentPattern): chunks, round_usage, model_log, emit_chunks=False ) - if not tool_calls: - if not allow_illegal_output: - raise ValueError("Model did not call any tools") - tool_calls = [ - ( - str(uuid.uuid4()), - ILLEGAL_OUTPUT_TOOL, - { - "raw": response_content, - }, - ) - ] - response_content = "" + if response_content: + replaced_tool_call = ( + str(uuid.uuid4()), + ILLEGAL_OUTPUT_TOOL, + { + "raw": response_content, + }, + ) + tool_calls.append(replaced_tool_call) messages.append(self._create_assistant_message("", tool_calls)) @@ -149,35 +128,23 @@ class FunctionCallStrategy(AgentPattern): if chunk_finish_reason: finish_reason = chunk_finish_reason + assert len(tool_calls) > 0 + # Process tool calls tool_outputs: dict[str, str] = {} - if tool_calls: - function_call_state = True - terminal_tool_seen = False - # Execute tools - for tool_call_id, tool_name, tool_args in tool_calls: - if tool_name == OUTPUT_TEXT_TOOL: - output_text_payload = self._format_output_text(tool_args.get("text")) - elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL: - data = tool_args.get("data") - structured_output_payload = cast(dict[str, Any] | None, data) - if tool_name == terminal_tool_name: - terminal_tool_seen = True - elif tool_name == FINAL_OUTPUT_TOOL: - final_text = self._format_output_text(tool_args.get("text")) - if tool_name == terminal_tool_name: - terminal_tool_seen = True - - tool_response, tool_files, _ = yield from self._handle_tool_call( - tool_name, tool_args, tool_call_id, messages, round_log - ) - tool_outputs[tool_name] = tool_response - # Track files produced by tools - output_files.extend(tool_files) - - if terminal_tool_seen: - terminal_output_seen = True + function_call_state = True + # Execute tools + for tool_call_id, tool_name, tool_args in tool_calls: + tool_response, tool_files, _ = yield from self._handle_tool_call( + tool_name, tool_args, tool_call_id, messages, round_log + ) + tool_outputs[tool_name] = tool_response + # Track files produced by tools + output_files.extend(tool_files) + if tool_response == TERMINAL_OUTPUT_MESSAGE: function_call_state = False + final_tool_args = tool_args + yield self._finish_log( round_log, data={ @@ -196,31 +163,16 @@ class FunctionCallStrategy(AgentPattern): # Return final result from core.agent.entities import AgentResult - output_payload: str | AgentResult.StructuredOutput - if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT, - output_text=None, - output_data=structured_output_payload, - ) - elif final_text: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, - output_text=final_text, - output_data=structured_output_payload, - ) - elif output_text_payload: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.OUTPUT_TEXT, - output_text=str(output_text_payload), - output_data=None, - ) + output_payload: str | dict + output_text = final_tool_args.get("text") + output_structured_payload = final_tool_args.get("data") + + if isinstance(output_structured_payload, dict): + output_payload = output_structured_payload + elif isinstance(output_text, str): + output_payload = output_text else: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.ILLEGAL_OUTPUT, - output_text="Model failed to produce a final output.", - output_data=None, - ) + raise ValueError("Final output is not a string or structured data.") return AgentResult( output=output_payload, diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py index ca4831b5c4..299e8c8baa 100644 --- a/api/core/agent/patterns/react.py +++ b/api/core/agent/patterns/react.py @@ -4,9 +4,9 @@ from __future__ import annotations import json from collections.abc import Generator -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any -from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext from core.agent.output_parser.cot_output_parser import CotAgentOutputParser from core.agent.output_tools import ( FINAL_OUTPUT_TOOL, @@ -60,10 +60,10 @@ class ReActStrategy(AgentPattern): def run( self, - prompt_messages: list[PromptMessage], + prompt_messages: + list[PromptMessage], model_parameters: dict[str, Any], - stop: list[str] = [], - stream: bool = True, + stop: list[str] ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: """Execute the ReAct agent strategy.""" # Initialize tracking @@ -137,16 +137,13 @@ class ReActStrategy(AgentPattern): messages_to_use = current_messages # Invoke model - chunks = cast( - Union[Generator[LLMResultChunk, None, None], LLMResult], - self.model_instance.invoke_llm( - prompt_messages=messages_to_use, - model_parameters=model_parameters, - stop=stop, - stream=False, - user=self.context.user_id or "", - callbacks=[], - ), + chunks = self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=False, + user=self.context.user_id or "", + callbacks=[], ) # Process response @@ -173,7 +170,7 @@ class ReActStrategy(AgentPattern): action_input={"raw": scratchpad.thought or ""}, ) scratchpad.action = illegal_action - scratchpad.action_str = json.dumps(illegal_action.to_dict()) + scratchpad.action_str = illegal_action.model_dump_json() react_state = True observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log) scratchpad.observation = observation @@ -218,33 +215,9 @@ class ReActStrategy(AgentPattern): # Return final result - from core.agent.entities import AgentResult + output_payload: str | dict - output_payload: str | AgentResult.StructuredOutput - if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT, - output_text=None, - output_data=structured_output_payload, - ) - elif final_text: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, - output_text=final_text, - output_data=structured_output_payload, - ) - elif output_text_payload: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.OUTPUT_TEXT, - output_text=str(output_text_payload), - output_data=structured_output_payload, - ) - else: - output_payload = AgentResult.StructuredOutput( - output_kind=AgentOutputKind.ILLEGAL_OUTPUT, - output_text="Model failed to produce a final output.", - output_data=structured_output_payload, - ) + # TODO return AgentResult( output=output_payload, @@ -336,7 +309,7 @@ class ReActStrategy(AgentPattern): def _handle_chunks( self, - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + chunks: LLMResult, llm_usage: dict[str, Any], model_log: AgentLog, current_messages: list[PromptMessage], @@ -353,25 +326,20 @@ class ReActStrategy(AgentPattern): """ usage_dict: dict[str, Any] = {} - # Convert non-streaming to streaming format if needed - if isinstance(chunks, LLMResult): - # Create a generator from the LLMResult - def result_to_chunks() -> Generator[LLMResultChunk, None, None]: - yield LLMResultChunk( - model=chunks.model, - prompt_messages=chunks.prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=chunks.message, - usage=chunks.usage, - finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do - ), - system_fingerprint=chunks.system_fingerprint or "", - ) + def result_to_chunks() -> Generator[LLMResultChunk, None, None]: + yield LLMResultChunk( + model=chunks.model, + prompt_messages=chunks.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=chunks.message, + usage=chunks.usage, + finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do + ), + system_fingerprint=chunks.system_fingerprint or "", + ) - streaming_chunks = result_to_chunks() - else: - streaming_chunks = chunks + streaming_chunks = result_to_chunks() react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict) @@ -397,7 +365,7 @@ class ReActStrategy(AgentPattern): if emit_chunks: yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) - else: + elif isinstance(chunk, str): # Text chunk chunk_text = str(chunk) scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text @@ -405,6 +373,8 @@ class ReActStrategy(AgentPattern): if emit_chunks: yield self._create_text_chunk(chunk_text, current_messages) + else: + raise ValueError(f"Unexpected chunk type: {type(chunk)}") # Update usage if usage_dict.get("usage"): diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py index ad26075291..945cded602 100644 --- a/api/core/agent/patterns/strategy_factory.py +++ b/api/core/agent/patterns/strategy_factory.py @@ -2,13 +2,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from core.agent.entities import AgentEntity, ExecutionContext from core.file.models import File from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelFeature +from ...app.entities.app_invoke_entities import InvokeFrom +from ...tools.entities.tool_entities import ToolInvokeFrom +from ..output_tools import build_agent_output_tools from .base import AgentPattern, ToolInvokeHook from .function_call import FunctionCallStrategy from .react import ReActStrategy @@ -25,6 +29,10 @@ class StrategyFactory: @staticmethod def create_strategy( + *, + tenant_id: str, + invoke_from: InvokeFrom, + tool_invoke_from: ToolInvokeFrom, model_features: list[ModelFeature], model_instance: ModelInstance, context: ExecutionContext, @@ -35,11 +43,16 @@ class StrategyFactory: agent_strategy: AgentEntity.Strategy | None = None, tool_invoke_hook: ToolInvokeHook | None = None, instruction: str = "", + structured_output_schema: Mapping[str, Any] | None = None ) -> AgentPattern: """ Create an appropriate strategy based on model features. Args: + tenant_id: + invoke_from: + tool_invoke_from: + structured_output_schema: model_features: List of model features/capabilities model_instance: Model instance to use context: Execution context containing trace/audit information @@ -54,6 +67,15 @@ class StrategyFactory: Returns: AgentStrategy instance """ + output_tools = build_agent_output_tools( + tenant_id=tenant_id, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + structured_output_schema=structured_output_schema + ) + + tools.extend(output_tools) + # If explicit strategy is provided and it's Function Calling, try to use it if supported if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING: if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d090542352..8749d4d29b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,13 +12,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select -from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext -from core.agent.output_tools import build_agent_output_tools, select_output_tool_names +from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext +from core.agent.output_tools import build_agent_output_tools from core.agent.patterns import StrategyFactory from core.app.entities.app_asset_entities import AppAssetFileTree from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app_assets.constants import AppAssetsAttrs -from core.file import File, FileTransferMethod, FileType, file_manager +from core.file import FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output @@ -188,12 +188,11 @@ class LLMNode(Node[LLMNodeData]): def _run(self) -> Generator: node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = "" # Initialize as empty string for consistency + usage: LLMUsage = LLMUsage.empty_usage() + finish_reason: str | None = None + reasoning_content: str = "" # Initialize as empty string for consistency clean_text = "" # Initialize clean_text to avoid UnboundLocalError - variable_pool = self.graph_runtime_state.variable_pool + variable_pool: VariablePool = self.graph_runtime_state.variable_pool try: # Parse prompt template to separate static messages and context references @@ -253,8 +252,9 @@ class LLMNode(Node[LLMNodeData]): ) query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template + memory_config = self.node_data.memory + if memory_config: + query = memory_config.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): @@ -296,9 +296,16 @@ class LLMNode(Node[LLMNodeData]): sandbox=self.graph_runtime_state.sandbox, ) - # Variables for outputs - generation_data: LLMGenerationData | None = None - structured_output: LLMStructuredOutput | None = None + structured_output_schema: Mapping[str, Any] | None + + if self.node_data.structured_output_enabled: + if not self.node_data.structured_output: + raise ValueError("structured_output_enabled is True but structured_output is not set") + structured_output_schema = LLMNode.fetch_structured_output_schema( + structured_output=self.node_data.structured_output + ) + else: + structured_output_schema = None if self.node_data.computer_use: sandbox = self.graph_runtime_state.sandbox @@ -312,6 +319,7 @@ class LLMNode(Node[LLMNodeData]): stop=stop, variable_pool=variable_pool, tool_dependencies=tool_dependencies, + structured_output_schema=structured_output_schema ) elif self.tool_call_enabled: generator = self._invoke_llm_with_tools( @@ -322,6 +330,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, node_inputs=node_inputs, process_data=process_data, + structured_output_schema=structured_output_schema ) else: # Use traditional LLM invocation @@ -331,8 +340,7 @@ class LLMNode(Node[LLMNodeData]): prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_enabled=self._node_data.structured_output_enabled, - structured_output=self._node_data.structured_output, + structured_output_schema=structured_output_schema, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self._node_id, @@ -498,8 +506,7 @@ class LLMNode(Node[LLMNodeData]): prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, user_id: str, - structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, + structured_output_schema: Mapping[str, Any] | None, file_saver: LLMFileSaver, file_outputs: list[File], node_id: str, @@ -513,10 +520,7 @@ class LLMNode(Node[LLMNodeData]): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if structured_output_enabled: - output_schema = LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, - ) + if structured_output_schema: request_start_time = time.perf_counter() invoke_result = invoke_llm_with_structured_output( @@ -524,7 +528,7 @@ class LLMNode(Node[LLMNodeData]): model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=output_schema, + json_schema=structured_output_schema, model_parameters=node_data_model.completion_params, stop=list(stop or []), user=user_id, @@ -1876,6 +1880,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, node_inputs: dict[str, Any], process_data: dict[str, Any], + structured_output_schema: Mapping[str, Any] | None ) -> Generator[NodeEventBase, None, LLMGenerationData]: """Invoke LLM with tools support (from Agent V2). @@ -1892,20 +1897,23 @@ class LLMNode(Node[LLMNodeData]): # Use factory to create appropriate strategy strategy = StrategyFactory.create_strategy( + tenant_id=self.tenant_id, + invoke_from=self.invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, model_features=model_features, model_instance=model_instance, tools=tool_instances, files=prompt_files, max_iterations=self._node_data.max_iterations or 10, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + structured_output_schema=structured_output_schema ) # Run strategy outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, - stop=list(stop or []), - stream=False, + stop=list(stop or []) ) result = yield from self._process_tool_outputs(outputs) @@ -1919,6 +1927,7 @@ class LLMNode(Node[LLMNodeData]): stop: Sequence[str] | None, variable_pool: VariablePool, tool_dependencies: ToolDependencies | None, + structured_output_schema: Mapping[str, Any] | None ) -> Generator[NodeEventBase, None, LLMGenerationData]: result: LLMGenerationData | None = None sandbox_output_files: list[File] = [] @@ -1927,37 +1936,25 @@ class LLMNode(Node[LLMNodeData]): with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) - structured_output_schema = None - if self._node_data.structured_output_enabled: - structured_output_schema = LLMNode.fetch_structured_output_schema( - structured_output=self._node_data.structured_output or {}, - ) - output_tools = build_agent_output_tools( + + strategy = StrategyFactory.create_strategy( tenant_id=self.tenant_id, invoke_from=self.invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, - output_tool_names=select_output_tool_names( - structured_output_enabled=self._node_data.structured_output_enabled, - include_illegal_output=True, - ), - structured_output_schema=structured_output_schema, - ) - - strategy = StrategyFactory.create_strategy( model_features=model_features, model_instance=model_instance, - tools=[session.bash_tool, *output_tools], + tools=[session.bash_tool], files=prompt_files, max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + structured_output_schema=structured_output_schema ) outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, - stop=list(stop or []), - stream=False, + stop=list(stop or []) ) result = yield from self._process_tool_outputs(outputs) @@ -2055,16 +2052,9 @@ class LLMNode(Node[LLMNodeData]): structured_output=self._node_data.structured_output or {}, ) tool_instances.extend( - build_agent_output_tools( - tenant_id=self.tenant_id, - invoke_from=self.invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, - output_tool_names=select_output_tool_names( - structured_output_enabled=self._node_data.structured_output_enabled, - include_illegal_output=True, - ), - structured_output_schema=structured_output_schema, - ) + build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + structured_output_schema=structured_output_schema) ) return tool_instances @@ -2485,6 +2475,7 @@ class LLMNode(Node[LLMNodeData]): content_position = 0 tool_call_seen_index: dict[str, int] = {} for trace_segment in trace_state.trace_segments: + # FIXME: These if will never happen if trace_segment.type == "thought": sequence.append({"type": "reasoning", "index": reasoning_index}) reasoning_index += 1 @@ -2564,32 +2555,25 @@ class LLMNode(Node[LLMNodeData]): if not isinstance(exception.value, AgentResult): raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception state.agent.agent_result = exception.value - if not state.agent.agent_result: + agent_result = state.agent.agent_result + if not agent_result: raise ValueError("No agent result found in tool outputs") - output_payload = state.agent.agent_result.output - structured_output_data: Mapping[str, Any] | None = None - if isinstance(output_payload, AgentResult.StructuredOutput): - output_kind = output_payload.output_kind - if output_kind == AgentOutputKind.ILLEGAL_OUTPUT: - raise ValueError("Agent returned illegal output") - if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}: - if not output_payload.output_text: - raise ValueError("Agent returned empty text output") - state.aggregate.text = output_payload.output_text - elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT: - if output_payload.output_data is None: - raise ValueError("Agent returned empty structured output") - else: - raise ValueError("Agent returned unsupported output kind") - - if output_payload.output_data is not None: - if not isinstance(output_payload.output_data, Mapping): - raise ValueError("Agent returned invalid structured output") - structured_output_data = output_payload.output_data + output_payload = agent_result.output + if isinstance(output_payload, dict): + state.aggregate.structured_output = LLMStructuredOutput( + structured_output=convert_file_refs_in_output( + output=output_payload, + json_schema=LLMNode.fetch_structured_output_schema( + structured_output=self._node_data.structured_output or {}, + ), + tenant_id=self.tenant_id, + ) + ) + state.aggregate.text = json.dumps(output_payload) + elif isinstance(output_payload, str): + state.aggregate.text = output_payload else: - if not output_payload: - raise ValueError("Agent returned empty output") - state.aggregate.text = str(output_payload) + raise ValueError(f"Unexpected output type: {type(output_payload)}") state.aggregate.files = state.agent.agent_result.files if state.agent.agent_result.usage: @@ -2597,17 +2581,6 @@ class LLMNode(Node[LLMNodeData]): if state.agent.agent_result.finish_reason: state.aggregate.finish_reason = state.agent.agent_result.finish_reason - if structured_output_data is not None: - output_schema = LLMNode.fetch_structured_output_schema( - structured_output=self._node_data.structured_output or {}, - ) - converted_output = convert_file_refs_in_output( - output=structured_output_data, - json_schema=output_schema, - tenant_id=self.tenant_id, - ) - state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output) - yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate) yield from self._close_streams() diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index 4a613e35b0..aec077124d 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -1,4 +1,3 @@ -import json from collections.abc import Generator from core.agent.entities import AgentScratchpadUnit @@ -64,7 +63,7 @@ def test_cot_output_parser(): output += result elif isinstance(result, AgentScratchpadUnit.Action): if test_case["action"]: - assert result.to_dict() == test_case["action"] - output += json.dumps(result.to_dict()) + assert result.model_dump() == test_case["action"] + output += result.model_dump_json() if test_case["output"]: assert output == test_case["output"] diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py index b0e0d44940..dd42b31bad 100644 --- a/api/tests/unit_tests/core/agent/patterns/test_base.py +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage class ConcreteAgentPattern(AgentPattern): """Concrete implementation of AgentPattern for testing.""" - def run(self, prompt_messages, model_parameters, stop=[], stream=True): + def run(self, prompt_messages, model_parameters, stop=[]): """Minimal implementation for testing.""" yield from [] diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py index 8214a56d3f..43d28d39b7 100644 --- a/api/tests/unit_tests/core/agent/test_agent_app_runner.py +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest -from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentPromptEntity, AgentResult +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.llm_entities import LLMUsage @@ -329,20 +329,15 @@ class TestAgentLogProcessing: ) result = AgentResult( - output=AgentResult.StructuredOutput( - output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER, - output_text="Final answer", - output_data=None, - ), + output="Final answer", files=[], usage=usage, finish_reason="stop", ) output_payload = result.output - assert isinstance(output_payload, AgentResult.StructuredOutput) - assert output_payload.output_text == "Final answer" - assert output_payload.output_kind == AgentOutputKind.FINAL_OUTPUT_ANSWER + assert isinstance(output_payload, str) + assert output_payload == "Final answer" assert result.files == [] assert result.usage == usage assert result.finish_reason == "stop" diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py index 5136f48aab..61b0ac0052 100644 --- a/api/tests/unit_tests/core/agent/test_entities.py +++ b/api/tests/unit_tests/core/agent/test_entities.py @@ -153,7 +153,7 @@ class TestAgentScratchpadUnit: action_input={"query": "test"}, ) - result = action.to_dict() + result = action.model_dump() assert result == { "action": "search",