refactor: remove union types

Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
Stream 2026-01-31 00:39:57 +08:00
parent a87560d667
commit 9ad49340bf
No known key found for this signature in database
GPG Key ID: 033728094B100D70
13 changed files with 257 additions and 361 deletions

View File

@ -1,10 +1,11 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@ -13,7 +14,6 @@ from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContentType,
@ -118,7 +118,6 @@ class AgentAppRunner(BaseAgentRunner):
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=False,
)
# Consume generator and collect result
@ -256,49 +255,30 @@ class AgentAppRunner(BaseAgentRunner):
raise
# Process final result
if isinstance(result, AgentResult):
output_payload = result.output
if isinstance(output_payload, AgentResult.StructuredOutput):
if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
raise ValueError("Agent returned illegal output")
if output_payload.output_kind not in {
AgentOutputKind.FINAL_OUTPUT_ANSWER,
AgentOutputKind.OUTPUT_TEXT,
}:
raise ValueError("Agent did not return text output")
if not output_payload.output_text:
raise ValueError("Agent returned empty text output")
final_answer = output_payload.output_text
else:
if not output_payload:
raise ValueError("Agent returned empty output")
final_answer = str(output_payload)
usage = result.usage or LLMUsage.empty_usage()
if not isinstance(result, AgentResult):
raise ValueError("Agent did not return AgentResult")
output_payload = result.output
if isinstance(output_payload, dict):
final_answer = json.dumps(output_payload, ensure_ascii=False)
elif isinstance(output_payload, str):
final_answer = output_payload
else:
raise ValueError("Final output is not a string or structured data.")
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
if False:
yield LLMResultChunk(
model="",
prompt_messages=[],
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=""),
usage=None,
),
)
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""

View File

@ -7,7 +7,7 @@ from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
from core.agent.output_tools import build_agent_output_tools
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -253,15 +253,9 @@ class BaseAgentRunner(AppRunner):
# save tool entity
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
output_tools = build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
output_tool_names=select_output_tool_names(
structured_output_enabled=False,
include_illegal_output=True,
),
)
output_tools = build_agent_output_tools(tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT)
for tool in output_tools:
tool_instances[tool.entity.identity.name] = tool

View File

@ -190,24 +190,11 @@ class AgentOutputKind(StrEnum):
ILLEGAL_OUTPUT = "illegal_output"
OutputKind = AgentOutputKind
class AgentResult(BaseModel):
"""
Agent execution result.
"""
class StructuredOutput(BaseModel):
"""
Structured output payload from output tools.
"""
output_kind: AgentOutputKind
output_text: str | None = None
output_data: Mapping[str, Any] | None = None
output: str | StructuredOutput = Field(default="", description="The generated output")
output: str | dict = Field(default="", description="The generated output")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

View File

@ -1,11 +1,12 @@
from __future__ import annotations
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
OUTPUT_TOOL_PROVIDER = "agent_output"
@ -22,63 +23,87 @@ OUTPUT_TOOL_NAMES: Sequence[str] = (
ILLEGAL_OUTPUT_TOOL,
)
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
def select_output_tool_names(
    *,
    structured_output_enabled: bool,
    include_illegal_output: bool = False,
) -> list[str]:
    """Return the names of the output tools to expose to the agent.

    The plain text-output tool is always included. Exactly one terminal
    tool is added: the structured-output tool when
    ``structured_output_enabled`` is True, otherwise the plain final-output
    tool. The illegal-output tool is appended only on request.
    """
    terminal_tool = FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
    selected = [OUTPUT_TEXT_TOOL, terminal_tool]
    if include_illegal_output:
        selected.append(ILLEGAL_OUTPUT_TOOL)
    return selected
TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
def select_terminal_tool_name(*, structured_output_enabled: bool) -> str:
def is_terminal_output_tool(tool_name: str) -> bool:
    """Return True when *tool_name* is a terminal (session-ending) output tool."""
    for candidate in TERMINAL_OUTPUT_TOOL_NAMES:
        if tool_name == candidate:
            return True
    return False
def get_terminal_tool_name(structured_output_enabled: bool) -> str:
    """Select which terminal output tool applies for the given output mode."""
    if structured_output_enabled:
        return FINAL_STRUCTURED_OUTPUT_TOOL
    return FINAL_OUTPUT_TOOL
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
def build_agent_output_tools(
*,
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
output_tool_names: Sequence[str],
structured_output_schema: dict[str, Any] | None = None,
structured_output_schema: Mapping[str, Any] | None = None,
) -> list[Tool]:
tools: list[Tool] = []
tool_names: list[str] = []
for tool_name in output_tool_names:
if tool_name not in OUTPUT_TOOL_NAME_SET:
raise ValueError(f"Unknown output tool name: {tool_name}")
if tool_name not in tool_names:
tool_names.append(tool_name)
for tool_name in tool_names:
tool = ToolManager.get_tool_runtime(
def get_tool_runtime(_tool_name: str) -> Tool:
return ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id=OUTPUT_TOOL_PROVIDER,
tool_name=tool_name,
tool_name=_tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
if tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and structured_output_schema:
tool.entity = tool.entity.model_copy(deep=True)
for parameter in tool.entity.parameters:
if parameter.name != "data":
continue
parameter.type = ToolParameter.ToolParameterType.OBJECT
parameter.form = ToolParameter.ToolParameterForm.LLM
parameter.required = True
parameter.input_schema = structured_output_schema
tools.append(tool)
tools: list[Tool] = [
get_tool_runtime(OUTPUT_TEXT_TOOL),
get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
]
if structured_output_schema:
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
raw_tool.entity = raw_tool.entity.model_copy(deep=True)
data_parameter = ToolParameter(
name="data",
type=ToolParameter.ToolParameterType.OBJECT,
form=ToolParameter.ToolParameterForm.LLM,
required=True,
input_schema=dict(structured_output_schema),
label=I18nObject(en_US="__Data", zh_Hans="__Data"),
)
raw_tool.entity.parameters = [data_parameter]
def invoke_tool(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
data = tool_parameters["data"]
if not data:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
if not isinstance(data, dict):
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)
else:
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
def invoke_tool(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)
return tools

View File

@ -60,8 +60,7 @@ class AgentPattern(ABC):
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
stop: list[str]
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass

View File

@ -5,13 +5,10 @@ import uuid
from collections.abc import Generator
from typing import Any, Literal, Protocol, Union, cast
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
from core.agent.entities import AgentLog, AgentResult
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
OUTPUT_TOOL_NAME_SET,
TERMINAL_OUTPUT_MESSAGE,
)
from core.file import File
from core.model_runtime.entities import (
@ -36,21 +33,11 @@ class FunctionCallStrategy(AgentPattern):
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
stop: list[str]
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET}
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_OUTPUT_TOOL
else:
raise ValueError("No terminal output tool configured")
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
# Initialize tracking
iteration_step: int = 1
@ -60,10 +47,10 @@ class FunctionCallStrategy(AgentPattern):
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
structured_output_payload: dict[str, Any] | None = None
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
output_text_payload: str | None = None
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
terminal_output_seen = False
class _LLMInvoker(Protocol):
def invoke_llm(
@ -87,11 +74,7 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log
# On last iteration, restrict tools to output tools
if iteration_step == max_iterations:
current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names]
else:
current_tools = prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@ -112,7 +95,7 @@ class FunctionCallStrategy(AgentPattern):
chunks = invoker.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
tools=prompt_tools,
stop=stop,
stream=False,
user=self.context.user_id,
@ -124,19 +107,15 @@ class FunctionCallStrategy(AgentPattern):
chunks, round_usage, model_log, emit_chunks=False
)
if not tool_calls:
if not allow_illegal_output:
raise ValueError("Model did not call any tools")
tool_calls = [
(
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
]
response_content = ""
if response_content:
replaced_tool_call = (
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
tool_calls.append(replaced_tool_call)
messages.append(self._create_assistant_message("", tool_calls))
@ -149,35 +128,23 @@ class FunctionCallStrategy(AgentPattern):
if chunk_finish_reason:
finish_reason = chunk_finish_reason
assert len(tool_calls) > 0
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
terminal_tool_seen = False
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
if tool_name == OUTPUT_TEXT_TOOL:
output_text_payload = self._format_output_text(tool_args.get("text"))
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
data = tool_args.get("data")
structured_output_payload = cast(dict[str, Any] | None, data)
if tool_name == terminal_tool_name:
terminal_tool_seen = True
elif tool_name == FINAL_OUTPUT_TOOL:
final_text = self._format_output_text(tool_args.get("text"))
if tool_name == terminal_tool_name:
terminal_tool_seen = True
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if terminal_tool_seen:
terminal_output_seen = True
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if tool_response == TERMINAL_OUTPUT_MESSAGE:
function_call_state = False
final_tool_args = tool_args
yield self._finish_log(
round_log,
data={
@ -196,31 +163,16 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
output_payload: str | AgentResult.StructuredOutput
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
output_text=None,
output_data=structured_output_payload,
)
elif final_text:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
output_text=final_text,
output_data=structured_output_payload,
)
elif output_text_payload:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.OUTPUT_TEXT,
output_text=str(output_text_payload),
output_data=None,
)
output_payload: str | dict
output_text = final_tool_args.get("text")
output_structured_payload = final_tool_args.get("data")
if isinstance(output_structured_payload, dict):
output_payload = output_structured_payload
elif isinstance(output_text, str):
output_payload = output_text
else:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
output_text="Model failed to produce a final output.",
output_data=None,
)
raise ValueError("Final output is not a string or structured data.")
return AgentResult(
output=output_payload,

View File

@ -4,9 +4,9 @@ from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
@ -60,10 +60,10 @@ class ReActStrategy(AgentPattern):
def run(
self,
prompt_messages: list[PromptMessage],
prompt_messages:
list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
stop: list[str]
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
@ -137,16 +137,13 @@ class ReActStrategy(AgentPattern):
messages_to_use = current_messages
# Invoke model
chunks = cast(
Union[Generator[LLMResultChunk, None, None], LLMResult],
self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=False,
user=self.context.user_id or "",
callbacks=[],
),
chunks = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=False,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
@ -173,7 +170,7 @@ class ReActStrategy(AgentPattern):
action_input={"raw": scratchpad.thought or ""},
)
scratchpad.action = illegal_action
scratchpad.action_str = json.dumps(illegal_action.to_dict())
scratchpad.action_str = illegal_action.model_dump_json()
react_state = True
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
scratchpad.observation = observation
@ -218,33 +215,9 @@ class ReActStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
output_payload: str | dict
output_payload: str | AgentResult.StructuredOutput
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
output_text=None,
output_data=structured_output_payload,
)
elif final_text:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
output_text=final_text,
output_data=structured_output_payload,
)
elif output_text_payload:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.OUTPUT_TEXT,
output_text=str(output_text_payload),
output_data=structured_output_payload,
)
else:
output_payload = AgentResult.StructuredOutput(
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
output_text="Model failed to produce a final output.",
output_data=structured_output_payload,
)
# TODO
return AgentResult(
output=output_payload,
@ -336,7 +309,7 @@ class ReActStrategy(AgentPattern):
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
chunks: LLMResult,
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
@ -353,25 +326,20 @@ class ReActStrategy(AgentPattern):
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
streaming_chunks = result_to_chunks()
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
@ -397,7 +365,7 @@ class ReActStrategy(AgentPattern):
if emit_chunks:
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
elif isinstance(chunk, str):
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
@ -405,6 +373,8 @@ class ReActStrategy(AgentPattern):
if emit_chunks:
yield self._create_text_chunk(chunk_text, current_messages)
else:
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
# Update usage
if usage_dict.get("usage"):

View File

@ -2,13 +2,17 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from ...app.entities.app_invoke_entities import InvokeFrom
from ...tools.entities.tool_entities import ToolInvokeFrom
from ..output_tools import build_agent_output_tools
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
@ -25,6 +29,10 @@ class StrategyFactory:
@staticmethod
def create_strategy(
*,
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
@ -35,11 +43,16 @@ class StrategyFactory:
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
structured_output_schema: Mapping[str, Any] | None = None
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
tenant_id:
invoke_from:
tool_invoke_from:
structured_output_schema:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
@ -54,6 +67,15 @@ class StrategyFactory:
Returns:
AgentStrategy instance
"""
output_tools = build_agent_output_tools(
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
structured_output_schema=structured_output_schema
)
tools.extend(output_tools)
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:

View File

@ -12,13 +12,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools
from core.agent.patterns import StrategyFactory
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app_assets.constants import AppAssetsAttrs
from core.file import File, FileTransferMethod, FileType, file_manager
from core.file import FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
@ -188,12 +188,11 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = "" # Initialize as empty string for consistency
usage: LLMUsage = LLMUsage.empty_usage()
finish_reason: str | None = None
reasoning_content: str = "" # Initialize as empty string for consistency
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
variable_pool = self.graph_runtime_state.variable_pool
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
@ -253,8 +252,9 @@ class LLMNode(Node[LLMNodeData]):
)
query: str | None = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
memory_config = self.node_data.memory
if memory_config:
query = memory_config.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@ -296,9 +296,16 @@ class LLMNode(Node[LLMNodeData]):
sandbox=self.graph_runtime_state.sandbox,
)
# Variables for outputs
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None
structured_output_schema: Mapping[str, Any] | None
if self.node_data.structured_output_enabled:
if not self.node_data.structured_output:
raise ValueError("structured_output_enabled is True but structured_output is not set")
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self.node_data.structured_output
)
else:
structured_output_schema = None
if self.node_data.computer_use:
sandbox = self.graph_runtime_state.sandbox
@ -312,6 +319,7 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
structured_output_schema=structured_output_schema
)
elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
@ -322,6 +330,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
structured_output_schema=structured_output_schema
)
else:
# Use traditional LLM invocation
@ -331,8 +340,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
structured_output_schema=structured_output_schema,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
@ -498,8 +506,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output_schema: Mapping[str, Any] | None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@ -513,10 +520,7 @@ class LLMNode(Node[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
if structured_output_schema:
request_start_time = time.perf_counter()
invoke_result = invoke_llm_with_structured_output(
@ -524,7 +528,7 @@ class LLMNode(Node[LLMNodeData]):
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=output_schema,
json_schema=structured_output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
@ -1876,6 +1880,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
structured_output_schema: Mapping[str, Any] | None
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
@ -1892,20 +1897,23 @@ class LLMNode(Node[LLMNodeData]):
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
)
# Run strategy
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=False,
stop=list(stop or [])
)
result = yield from self._process_tool_outputs(outputs)
@ -1919,6 +1927,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
structured_output_schema: Mapping[str, Any] | None
) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
sandbox_output_files: list[File] = []
@ -1927,37 +1936,25 @@ class LLMNode(Node[LLMNodeData]):
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
prompt_files = self._extract_prompt_files(variable_pool)
model_features = self._get_model_features(model_instance)
structured_output_schema = None
if self._node_data.structured_output_enabled:
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
output_tools = build_agent_output_tools(
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
output_tool_names=select_output_tool_names(
structured_output_enabled=self._node_data.structured_output_enabled,
include_illegal_output=True,
),
structured_output_schema=structured_output_schema,
)
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=model_instance,
tools=[session.bash_tool, *output_tools],
tools=[session.bash_tool],
files=prompt_files,
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
)
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=False,
stop=list(stop or [])
)
result = yield from self._process_tool_outputs(outputs)
@ -2055,16 +2052,9 @@ class LLMNode(Node[LLMNodeData]):
structured_output=self._node_data.structured_output or {},
)
tool_instances.extend(
build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
output_tool_names=select_output_tool_names(
structured_output_enabled=self._node_data.structured_output_enabled,
include_illegal_output=True,
),
structured_output_schema=structured_output_schema,
)
build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema)
)
return tool_instances
@ -2485,6 +2475,7 @@ class LLMNode(Node[LLMNodeData]):
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
# FIXME: these branches can never be taken — confirm and remove
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@@ -2564,32 +2555,25 @@ class LLMNode(Node[LLMNodeData]):
if not isinstance(exception.value, AgentResult):
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
state.agent.agent_result = exception.value
if not state.agent.agent_result:
agent_result = state.agent.agent_result
if not agent_result:
raise ValueError("No agent result found in tool outputs")
output_payload = state.agent.agent_result.output
structured_output_data: Mapping[str, Any] | None = None
if isinstance(output_payload, AgentResult.StructuredOutput):
output_kind = output_payload.output_kind
if output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
raise ValueError("Agent returned illegal output")
if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}:
if not output_payload.output_text:
raise ValueError("Agent returned empty text output")
state.aggregate.text = output_payload.output_text
elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT:
if output_payload.output_data is None:
raise ValueError("Agent returned empty structured output")
else:
raise ValueError("Agent returned unsupported output kind")
if output_payload.output_data is not None:
if not isinstance(output_payload.output_data, Mapping):
raise ValueError("Agent returned invalid structured output")
structured_output_data = output_payload.output_data
output_payload = agent_result.output
if isinstance(output_payload, dict):
state.aggregate.structured_output = LLMStructuredOutput(
structured_output=convert_file_refs_in_output(
output=output_payload,
json_schema=LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
),
tenant_id=self.tenant_id,
)
)
state.aggregate.text = json.dumps(output_payload)
elif isinstance(output_payload, str):
state.aggregate.text = output_payload
else:
if not output_payload:
raise ValueError("Agent returned empty output")
state.aggregate.text = str(output_payload)
raise ValueError(f"Unexpected output type: {type(output_payload)}")
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
@@ -2597,17 +2581,6 @@ class LLMNode(Node[LLMNodeData]):
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
if structured_output_data is not None:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
converted_output = convert_file_refs_in_output(
output=structured_output_data,
json_schema=output_schema,
tenant_id=self.tenant_id,
)
state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output)
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
@@ -64,7 +63,7 @@ def test_cot_output_parser():
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
assert result.to_dict() == test_case["action"]
output += json.dumps(result.to_dict())
assert result.model_dump() == test_case["action"]
output += result.model_dump_json()
if test_case["output"]:
assert output == test_case["output"]

View File

@@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
"""Concrete implementation of AgentPattern for testing."""
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
def run(self, prompt_messages, model_parameters, stop=[]):
"""Minimal implementation for testing."""
yield from []

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentPromptEntity, AgentResult
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -329,20 +329,15 @@ class TestAgentLogProcessing:
)
result = AgentResult(
output=AgentResult.StructuredOutput(
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
output_text="Final answer",
output_data=None,
),
output="Final answer",
files=[],
usage=usage,
finish_reason="stop",
)
output_payload = result.output
assert isinstance(output_payload, AgentResult.StructuredOutput)
assert output_payload.output_text == "Final answer"
assert output_payload.output_kind == AgentOutputKind.FINAL_OUTPUT_ANSWER
assert isinstance(output_payload, str)
assert output_payload == "Final answer"
assert result.files == []
assert result.usage == usage
assert result.finish_reason == "stop"

View File

@@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
action_input={"query": "test"},
)
result = action.to_dict()
result = action.model_dump()
assert result == {
"action": "search",