mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
refactor: remove union types
Signed-off-by: Stream <Stream_2@qq.com>
This commit is contained in:
parent
a87560d667
commit
9ad49340bf
@ -1,10 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
@ -13,7 +14,6 @@ from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
@ -118,7 +118,6 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
@ -256,49 +255,30 @@ class AgentAppRunner(BaseAgentRunner):
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
output_payload = result.output
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
if output_payload.output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_payload.output_kind not in {
|
||||
AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
AgentOutputKind.OUTPUT_TEXT,
|
||||
}:
|
||||
raise ValueError("Agent did not return text output")
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
final_answer = output_payload.output_text
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
final_answer = str(output_payload)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
if not isinstance(result, AgentResult):
|
||||
raise ValueError("Agent did not return AgentResult")
|
||||
output_payload = result.output
|
||||
if isinstance(output_payload, dict):
|
||||
final_answer = json.dumps(output_payload, ensure_ascii=False)
|
||||
elif isinstance(output_payload, str):
|
||||
final_answer = output_payload
|
||||
else:
|
||||
raise ValueError("Final output is not a string or structured data.")
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
if False:
|
||||
yield LLMResultChunk(
|
||||
model="",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=None,
|
||||
),
|
||||
)
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Union, cast
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -253,15 +253,9 @@ class BaseAgentRunner(AppRunner):
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
output_tool_names=select_output_tool_names(
|
||||
structured_output_enabled=False,
|
||||
include_illegal_output=True,
|
||||
),
|
||||
)
|
||||
output_tools = build_agent_output_tools(tenant_id=self.tenant_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT)
|
||||
for tool in output_tools:
|
||||
tool_instances[tool.entity.identity.name] = tool
|
||||
|
||||
|
||||
@ -190,24 +190,11 @@ class AgentOutputKind(StrEnum):
|
||||
ILLEGAL_OUTPUT = "illegal_output"
|
||||
|
||||
|
||||
OutputKind = AgentOutputKind
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""
|
||||
Agent execution result.
|
||||
"""
|
||||
|
||||
class StructuredOutput(BaseModel):
|
||||
"""
|
||||
Structured output payload from output tools.
|
||||
"""
|
||||
|
||||
output_kind: AgentOutputKind
|
||||
output_text: str | None = None
|
||||
output_data: Mapping[str, Any] | None = None
|
||||
|
||||
output: str | StructuredOutput = Field(default="", description="The generated output")
|
||||
output: str | dict = Field(default="", description="The generated output")
|
||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
OUTPUT_TOOL_PROVIDER = "agent_output"
|
||||
@ -22,63 +23,87 @@ OUTPUT_TOOL_NAMES: Sequence[str] = (
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
)
|
||||
|
||||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
||||
TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
|
||||
|
||||
|
||||
def select_output_tool_names(
|
||||
*,
|
||||
structured_output_enabled: bool,
|
||||
include_illegal_output: bool = False,
|
||||
) -> list[str]:
|
||||
tool_names = [OUTPUT_TEXT_TOOL]
|
||||
if structured_output_enabled:
|
||||
tool_names.append(FINAL_STRUCTURED_OUTPUT_TOOL)
|
||||
else:
|
||||
tool_names.append(FINAL_OUTPUT_TOOL)
|
||||
if include_illegal_output:
|
||||
tool_names.append(ILLEGAL_OUTPUT_TOOL)
|
||||
return tool_names
|
||||
TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
|
||||
|
||||
|
||||
def select_terminal_tool_name(*, structured_output_enabled: bool) -> str:
|
||||
def is_terminal_output_tool(tool_name: str) -> bool:
|
||||
return tool_name in TERMINAL_OUTPUT_TOOL_NAMES
|
||||
|
||||
|
||||
def get_terminal_tool_name(structured_output_enabled: bool) -> str:
|
||||
return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
|
||||
|
||||
|
||||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
||||
|
||||
|
||||
def build_agent_output_tools(
|
||||
*,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
tool_invoke_from: ToolInvokeFrom,
|
||||
output_tool_names: Sequence[str],
|
||||
structured_output_schema: dict[str, Any] | None = None,
|
||||
structured_output_schema: Mapping[str, Any] | None = None,
|
||||
) -> list[Tool]:
|
||||
tools: list[Tool] = []
|
||||
tool_names: list[str] = []
|
||||
for tool_name in output_tool_names:
|
||||
if tool_name not in OUTPUT_TOOL_NAME_SET:
|
||||
raise ValueError(f"Unknown output tool name: {tool_name}")
|
||||
if tool_name not in tool_names:
|
||||
tool_names.append(tool_name)
|
||||
|
||||
for tool_name in tool_names:
|
||||
tool = ToolManager.get_tool_runtime(
|
||||
def get_tool_runtime(_tool_name: str) -> Tool:
|
||||
return ToolManager.get_tool_runtime(
|
||||
provider_type=ToolProviderType.BUILT_IN,
|
||||
provider_id=OUTPUT_TOOL_PROVIDER,
|
||||
tool_name=tool_name,
|
||||
tool_name=_tool_name,
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
|
||||
if tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and structured_output_schema:
|
||||
tool.entity = tool.entity.model_copy(deep=True)
|
||||
for parameter in tool.entity.parameters:
|
||||
if parameter.name != "data":
|
||||
continue
|
||||
parameter.type = ToolParameter.ToolParameterType.OBJECT
|
||||
parameter.form = ToolParameter.ToolParameterForm.LLM
|
||||
parameter.required = True
|
||||
parameter.input_schema = structured_output_schema
|
||||
tools.append(tool)
|
||||
tools: list[Tool] = [
|
||||
get_tool_runtime(OUTPUT_TEXT_TOOL),
|
||||
get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
|
||||
]
|
||||
|
||||
if structured_output_schema:
|
||||
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
|
||||
raw_tool.entity = raw_tool.entity.model_copy(deep=True)
|
||||
data_parameter = ToolParameter(
|
||||
name="data",
|
||||
type=ToolParameter.ToolParameterType.OBJECT,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
required=True,
|
||||
input_schema=dict(structured_output_schema),
|
||||
label=I18nObject(en_US="__Data", zh_Hans="__Data"),
|
||||
)
|
||||
raw_tool.entity.parameters = [data_parameter]
|
||||
|
||||
def invoke_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> ToolInvokeMessage:
|
||||
data = tool_parameters["data"]
|
||||
if not data:
|
||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
|
||||
if not isinstance(data, dict):
|
||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
|
||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
|
||||
|
||||
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
|
||||
tools.append(raw_tool)
|
||||
else:
|
||||
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
|
||||
|
||||
def invoke_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
|
||||
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
|
||||
tools.append(raw_tool)
|
||||
|
||||
return tools
|
||||
|
||||
@ -60,8 +60,7 @@ class AgentPattern(ABC):
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
stop: list[str]
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
@ -5,13 +5,10 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Protocol, Union, cast
|
||||
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
OUTPUT_TEXT_TOOL,
|
||||
OUTPUT_TOOL_NAME_SET,
|
||||
TERMINAL_OUTPUT_MESSAGE,
|
||||
)
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
@ -36,21 +33,11 @@ class FunctionCallStrategy(AgentPattern):
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
stop: list[str]
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
|
||||
available_output_tool_names = {tool.name for tool in prompt_tools if tool.name in OUTPUT_TOOL_NAME_SET}
|
||||
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
||||
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
|
||||
terminal_tool_name = FINAL_OUTPUT_TOOL
|
||||
else:
|
||||
raise ValueError("No terminal output tool configured")
|
||||
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
@ -60,10 +47,10 @@ class FunctionCallStrategy(AgentPattern):
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
structured_output_payload: dict[str, Any] | None = None
|
||||
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
|
||||
output_text_payload: str | None = None
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
terminal_output_seen = False
|
||||
|
||||
class _LLMInvoker(Protocol):
|
||||
def invoke_llm(
|
||||
@ -87,11 +74,7 @@ class FunctionCallStrategy(AgentPattern):
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, restrict tools to output tools
|
||||
if iteration_step == max_iterations:
|
||||
current_tools = [tool for tool in prompt_tools if tool.name in available_output_tool_names]
|
||||
else:
|
||||
current_tools = prompt_tools
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
@ -112,7 +95,7 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks = invoker.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
tools=prompt_tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id,
|
||||
@ -124,19 +107,15 @@ class FunctionCallStrategy(AgentPattern):
|
||||
chunks, round_usage, model_log, emit_chunks=False
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
if not allow_illegal_output:
|
||||
raise ValueError("Model did not call any tools")
|
||||
tool_calls = [
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
]
|
||||
response_content = ""
|
||||
if response_content:
|
||||
replaced_tool_call = (
|
||||
str(uuid.uuid4()),
|
||||
ILLEGAL_OUTPUT_TOOL,
|
||||
{
|
||||
"raw": response_content,
|
||||
},
|
||||
)
|
||||
tool_calls.append(replaced_tool_call)
|
||||
|
||||
messages.append(self._create_assistant_message("", tool_calls))
|
||||
|
||||
@ -149,35 +128,23 @@ class FunctionCallStrategy(AgentPattern):
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
assert len(tool_calls) > 0
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
terminal_tool_seen = False
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
if tool_name == OUTPUT_TEXT_TOOL:
|
||||
output_text_payload = self._format_output_text(tool_args.get("text"))
|
||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
||||
data = tool_args.get("data")
|
||||
structured_output_payload = cast(dict[str, Any] | None, data)
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
elif tool_name == FINAL_OUTPUT_TOOL:
|
||||
final_text = self._format_output_text(tool_args.get("text"))
|
||||
if tool_name == terminal_tool_name:
|
||||
terminal_tool_seen = True
|
||||
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
if terminal_tool_seen:
|
||||
terminal_output_seen = True
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
if tool_response == TERMINAL_OUTPUT_MESSAGE:
|
||||
function_call_state = False
|
||||
final_tool_args = tool_args
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
@ -196,31 +163,16 @@ class FunctionCallStrategy(AgentPattern):
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
|
||||
output_text=None,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=None,
|
||||
)
|
||||
output_payload: str | dict
|
||||
output_text = final_tool_args.get("text")
|
||||
output_structured_payload = final_tool_args.get("data")
|
||||
|
||||
if isinstance(output_structured_payload, dict):
|
||||
output_payload = output_structured_payload
|
||||
elif isinstance(output_text, str):
|
||||
output_payload = output_text
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=None,
|
||||
)
|
||||
raise ValueError("Final output is not a string or structured data.")
|
||||
|
||||
return AgentResult(
|
||||
output=output_payload,
|
||||
|
||||
@ -4,9 +4,9 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentOutputKind, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.agent.output_tools import (
|
||||
FINAL_OUTPUT_TOOL,
|
||||
@ -60,10 +60,10 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages:
|
||||
list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
stop: list[str]
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
@ -137,16 +137,13 @@ class ReActStrategy(AgentPattern):
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks = cast(
|
||||
Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
),
|
||||
chunks = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
@ -173,7 +170,7 @@ class ReActStrategy(AgentPattern):
|
||||
action_input={"raw": scratchpad.thought or ""},
|
||||
)
|
||||
scratchpad.action = illegal_action
|
||||
scratchpad.action_str = json.dumps(illegal_action.to_dict())
|
||||
scratchpad.action_str = illegal_action.model_dump_json()
|
||||
react_state = True
|
||||
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
|
||||
scratchpad.observation = observation
|
||||
@ -218,33 +215,9 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
output_payload: str | dict
|
||||
|
||||
output_payload: str | AgentResult.StructuredOutput
|
||||
if terminal_tool_name == FINAL_STRUCTURED_OUTPUT_TOOL and terminal_output_seen:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_STRUCTURED_OUTPUT,
|
||||
output_text=None,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif final_text:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text=final_text,
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
elif output_text_payload:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.OUTPUT_TEXT,
|
||||
output_text=str(output_text_payload),
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
else:
|
||||
output_payload = AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.ILLEGAL_OUTPUT,
|
||||
output_text="Model failed to produce a final output.",
|
||||
output_data=structured_output_payload,
|
||||
)
|
||||
# TODO
|
||||
|
||||
return AgentResult(
|
||||
output=output_payload,
|
||||
@ -336,7 +309,7 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
chunks: LLMResult,
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
@ -353,25 +326,20 @@ class ReActStrategy(AgentPattern):
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
streaming_chunks = result_to_chunks()
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
@ -397,7 +365,7 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
elif isinstance(chunk, str):
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
@ -405,6 +373,8 @@ class ReActStrategy(AgentPattern):
|
||||
|
||||
if emit_chunks:
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
else:
|
||||
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
|
||||
@ -2,13 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from ...app.entities.app_invoke_entities import InvokeFrom
|
||||
from ...tools.entities.tool_entities import ToolInvokeFrom
|
||||
from ..output_tools import build_agent_output_tools
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
@ -25,6 +29,10 @@ class StrategyFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
*,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
tool_invoke_from: ToolInvokeFrom,
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
@ -35,11 +43,16 @@ class StrategyFactory:
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
structured_output_schema: Mapping[str, Any] | None = None
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
tenant_id:
|
||||
invoke_from:
|
||||
tool_invoke_from:
|
||||
structured_output_schema:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
@ -54,6 +67,15 @@ class StrategyFactory:
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
output_tools = build_agent_output_tools(
|
||||
tenant_id=tenant_id,
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
structured_output_schema=structured_output_schema
|
||||
)
|
||||
|
||||
tools.extend(output_tools)
|
||||
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
|
||||
@ -12,13 +12,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools, select_output_tool_names
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.output_tools import build_agent_output_tools
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app_assets.constants import AppAssetsAttrs
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.file import FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
@ -188,12 +188,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
clean_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = "" # Initialize as empty string for consistency
|
||||
usage: LLMUsage = LLMUsage.empty_usage()
|
||||
finish_reason: str | None = None
|
||||
reasoning_content: str = "" # Initialize as empty string for consistency
|
||||
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
# Parse prompt template to separate static messages and context references
|
||||
@ -253,8 +252,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
)
|
||||
|
||||
query: str | None = None
|
||||
if self.node_data.memory:
|
||||
query = self.node_data.memory.query_prompt_template
|
||||
memory_config = self.node_data.memory
|
||||
if memory_config:
|
||||
query = memory_config.query_prompt_template
|
||||
if not query and (
|
||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
):
|
||||
@ -296,9 +296,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sandbox=self.graph_runtime_state.sandbox,
|
||||
)
|
||||
|
||||
# Variables for outputs
|
||||
generation_data: LLMGenerationData | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not self.node_data.structured_output:
|
||||
raise ValueError("structured_output_enabled is True but structured_output is not set")
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self.node_data.structured_output
|
||||
)
|
||||
else:
|
||||
structured_output_schema = None
|
||||
|
||||
if self.node_data.computer_use:
|
||||
sandbox = self.graph_runtime_state.sandbox
|
||||
@ -312,6 +319,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop=stop,
|
||||
variable_pool=variable_pool,
|
||||
tool_dependencies=tool_dependencies,
|
||||
structured_output_schema=structured_output_schema
|
||||
)
|
||||
elif self.tool_call_enabled:
|
||||
generator = self._invoke_llm_with_tools(
|
||||
@ -322,6 +330,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
structured_output_schema=structured_output_schema
|
||||
)
|
||||
else:
|
||||
# Use traditional LLM invocation
|
||||
@ -331,8 +340,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output=self._node_data.structured_output,
|
||||
structured_output_schema=structured_output_schema,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
@ -498,8 +506,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
structured_output_schema: Mapping[str, Any] | None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
@ -513,10 +520,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {node_data_model.name}")
|
||||
|
||||
if structured_output_enabled:
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=structured_output or {},
|
||||
)
|
||||
if structured_output_schema:
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
invoke_result = invoke_llm_with_structured_output(
|
||||
@ -524,7 +528,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=output_schema,
|
||||
json_schema=structured_output_schema,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
user=user_id,
|
||||
@ -1876,6 +1880,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
node_inputs: dict[str, Any],
|
||||
process_data: dict[str, Any],
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
"""Invoke LLM with tools support (from Agent V2).
|
||||
|
||||
@ -1892,20 +1897,23 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=tool_instances,
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 10,
|
||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||
structured_output_schema=structured_output_schema
|
||||
)
|
||||
|
||||
# Run strategy
|
||||
outputs = strategy.run(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=False,
|
||||
stop=list(stop or [])
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
@ -1919,6 +1927,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None,
|
||||
variable_pool: VariablePool,
|
||||
tool_dependencies: ToolDependencies | None,
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
result: LLMGenerationData | None = None
|
||||
sandbox_output_files: list[File] = []
|
||||
@ -1927,37 +1936,25 @@ class LLMNode(Node[LLMNodeData]):
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||
prompt_files = self._extract_prompt_files(variable_pool)
|
||||
model_features = self._get_model_features(model_instance)
|
||||
structured_output_schema = None
|
||||
if self._node_data.structured_output_enabled:
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
output_tools = build_agent_output_tools(
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
output_tool_names=select_output_tool_names(
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
include_illegal_output=True,
|
||||
),
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=[session.bash_tool, *output_tools],
|
||||
tools=[session.bash_tool],
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||
structured_output_schema=structured_output_schema
|
||||
)
|
||||
|
||||
outputs = strategy.run(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or []),
|
||||
stream=False,
|
||||
stop=list(stop or [])
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
@ -2055,16 +2052,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
tool_instances.extend(
|
||||
build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
output_tool_names=select_output_tool_names(
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
include_illegal_output=True,
|
||||
),
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema)
|
||||
)
|
||||
|
||||
return tool_instances
|
||||
@ -2485,6 +2475,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
content_position = 0
|
||||
tool_call_seen_index: dict[str, int] = {}
|
||||
for trace_segment in trace_state.trace_segments:
|
||||
# FIXME: These if will never happen
|
||||
if trace_segment.type == "thought":
|
||||
sequence.append({"type": "reasoning", "index": reasoning_index})
|
||||
reasoning_index += 1
|
||||
@ -2564,32 +2555,25 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if not isinstance(exception.value, AgentResult):
|
||||
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
|
||||
state.agent.agent_result = exception.value
|
||||
if not state.agent.agent_result:
|
||||
agent_result = state.agent.agent_result
|
||||
if not agent_result:
|
||||
raise ValueError("No agent result found in tool outputs")
|
||||
output_payload = state.agent.agent_result.output
|
||||
structured_output_data: Mapping[str, Any] | None = None
|
||||
if isinstance(output_payload, AgentResult.StructuredOutput):
|
||||
output_kind = output_payload.output_kind
|
||||
if output_kind == AgentOutputKind.ILLEGAL_OUTPUT:
|
||||
raise ValueError("Agent returned illegal output")
|
||||
if output_kind in {AgentOutputKind.FINAL_OUTPUT_ANSWER, AgentOutputKind.OUTPUT_TEXT}:
|
||||
if not output_payload.output_text:
|
||||
raise ValueError("Agent returned empty text output")
|
||||
state.aggregate.text = output_payload.output_text
|
||||
elif output_kind == AgentOutputKind.FINAL_STRUCTURED_OUTPUT:
|
||||
if output_payload.output_data is None:
|
||||
raise ValueError("Agent returned empty structured output")
|
||||
else:
|
||||
raise ValueError("Agent returned unsupported output kind")
|
||||
|
||||
if output_payload.output_data is not None:
|
||||
if not isinstance(output_payload.output_data, Mapping):
|
||||
raise ValueError("Agent returned invalid structured output")
|
||||
structured_output_data = output_payload.output_data
|
||||
output_payload = agent_result.output
|
||||
if isinstance(output_payload, dict):
|
||||
state.aggregate.structured_output = LLMStructuredOutput(
|
||||
structured_output=convert_file_refs_in_output(
|
||||
output=output_payload,
|
||||
json_schema=LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
),
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
state.aggregate.text = json.dumps(output_payload)
|
||||
elif isinstance(output_payload, str):
|
||||
state.aggregate.text = output_payload
|
||||
else:
|
||||
if not output_payload:
|
||||
raise ValueError("Agent returned empty output")
|
||||
state.aggregate.text = str(output_payload)
|
||||
raise ValueError(f"Unexpected output type: {type(output_payload)}")
|
||||
|
||||
state.aggregate.files = state.agent.agent_result.files
|
||||
if state.agent.agent_result.usage:
|
||||
@ -2597,17 +2581,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if state.agent.agent_result.finish_reason:
|
||||
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
||||
|
||||
if structured_output_data is not None:
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
converted_output = convert_file_refs_in_output(
|
||||
output=structured_output_data,
|
||||
json_schema=output_schema,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
state.aggregate.structured_output = LLMStructuredOutput(structured_output=converted_output)
|
||||
|
||||
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
||||
yield from self._close_streams()
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
@ -64,7 +63,7 @@ def test_cot_output_parser():
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
assert result.model_dump() == test_case["action"]
|
||||
output += result.model_dump_json()
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
||||
|
||||
@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
class ConcreteAgentPattern(AgentPattern):
|
||||
"""Concrete implementation of AgentPattern for testing."""
|
||||
|
||||
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
|
||||
def run(self, prompt_messages, model_parameters, stop=[]):
|
||||
"""Minimal implementation for testing."""
|
||||
yield from []
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentOutputKind, AgentPromptEntity, AgentResult
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
@ -329,20 +329,15 @@ class TestAgentLogProcessing:
|
||||
)
|
||||
|
||||
result = AgentResult(
|
||||
output=AgentResult.StructuredOutput(
|
||||
output_kind=AgentOutputKind.FINAL_OUTPUT_ANSWER,
|
||||
output_text="Final answer",
|
||||
output_data=None,
|
||||
),
|
||||
output="Final answer",
|
||||
files=[],
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
output_payload = result.output
|
||||
assert isinstance(output_payload, AgentResult.StructuredOutput)
|
||||
assert output_payload.output_text == "Final answer"
|
||||
assert output_payload.output_kind == AgentOutputKind.FINAL_OUTPUT_ANSWER
|
||||
assert isinstance(output_payload, str)
|
||||
assert output_payload == "Final answer"
|
||||
assert result.files == []
|
||||
assert result.usage == usage
|
||||
assert result.finish_reason == "stop"
|
||||
|
||||
@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
|
||||
action_input={"query": "test"},
|
||||
)
|
||||
|
||||
result = action.to_dict()
|
||||
result = action.model_dump()
|
||||
|
||||
assert result == {
|
||||
"action": "search",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user