mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
revert: add tools for output in agent mode
feat: hide output tools and improve JSON formatting for structured output feat: hide output tools and improve JSON formatting for structured output fix: handle prompt template correctly to extract selectors for step run fix: emit StreamChunkEvent correctly for sandbox agent chore: better debug message fix: incorrect output tool runtime selection fix: type issues fix: align parameter list fix: align parameter list fix: hide internal builtin providers from tool list vibe: implement file structured output vibe: implement file structured output fix: refix parameter for tool fix: crash fix: crash refactor: remove union types fix: type check Merge branch 'feat/structured-output-with-sandbox' into feat/support-agent-sandbox fix: provide json as text fix: provide json as text fix: get AgentResult correctly fix: provides correct prompts, tools and terminal predicates fix: provides correct prompts, tools and terminal predicates fix: circular import feat: support structured output in sandbox and tool mode
This commit is contained in:
parent
25065a4f2f
commit
e0082dbf18
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -24,7 +23,7 @@ from core.model_runtime.entities import (
|
|||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMeta
|
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from models.model import Message
|
from models.model import Message
|
||||||
|
|
||||||
@ -103,13 +102,11 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
agent_strategy=self.config.strategy,
|
agent_strategy=self.config.strategy,
|
||||||
tool_invoke_hook=tool_invoke_hook,
|
tool_invoke_hook=tool_invoke_hook,
|
||||||
instruction=instruction,
|
instruction=instruction,
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
|
||||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize state variables
|
# Initialize state variables
|
||||||
current_agent_thought_id = None
|
current_agent_thought_id = None
|
||||||
|
has_published_thought = False
|
||||||
current_tool_name: str | None = None
|
current_tool_name: str | None = None
|
||||||
self._current_message_file_ids: list[str] = []
|
self._current_message_file_ids: list[str] = []
|
||||||
|
|
||||||
@ -121,6 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
model_parameters=app_generate_entity.model_conf.parameters,
|
model_parameters=app_generate_entity.model_conf.parameters,
|
||||||
stop=app_generate_entity.model_conf.stop,
|
stop=app_generate_entity.model_conf.stop,
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Consume generator and collect result
|
# Consume generator and collect result
|
||||||
@ -135,10 +133,17 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if isinstance(output, LLMResultChunk):
|
if isinstance(output, LLMResultChunk):
|
||||||
# No more expect streaming data
|
# Handle LLM chunk
|
||||||
continue
|
if current_agent_thought_id and not has_published_thought:
|
||||||
|
self.queue_manager.publish(
|
||||||
|
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
|
)
|
||||||
|
has_published_thought = True
|
||||||
|
|
||||||
else:
|
yield output
|
||||||
|
|
||||||
|
elif isinstance(output, AgentLog):
|
||||||
# Handle Agent Log using log_type for type-safe dispatch
|
# Handle Agent Log using log_type for type-safe dispatch
|
||||||
if output.status == AgentLog.LogStatus.START:
|
if output.status == AgentLog.LogStatus.START:
|
||||||
if output.log_type == AgentLog.LogType.ROUND:
|
if output.log_type == AgentLog.LogType.ROUND:
|
||||||
@ -151,6 +156,7 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
tool_input="",
|
tool_input="",
|
||||||
messages_ids=message_file_ids,
|
messages_ids=message_file_ids,
|
||||||
)
|
)
|
||||||
|
has_published_thought = False
|
||||||
|
|
||||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||||
if current_agent_thought_id is None:
|
if current_agent_thought_id is None:
|
||||||
@ -258,30 +264,23 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# Process final result
|
# Process final result
|
||||||
if not isinstance(result, AgentResult):
|
if isinstance(result, AgentResult):
|
||||||
raise ValueError("Agent did not return AgentResult")
|
final_answer = result.text
|
||||||
output_payload = result.output
|
usage = result.usage or LLMUsage.empty_usage()
|
||||||
if isinstance(output_payload, dict):
|
|
||||||
final_answer = json.dumps(output_payload, ensure_ascii=False)
|
|
||||||
elif isinstance(output_payload, str):
|
|
||||||
final_answer = output_payload
|
|
||||||
else:
|
|
||||||
raise ValueError("Final output is not a string or structured data.")
|
|
||||||
usage = result.usage or LLMUsage.empty_usage()
|
|
||||||
|
|
||||||
# Publish end event
|
# Publish end event
|
||||||
self.queue_manager.publish(
|
self.queue_manager.publish(
|
||||||
QueueMessageEndEvent(
|
QueueMessageEndEvent(
|
||||||
llm_result=LLMResult(
|
llm_result=LLMResult(
|
||||||
model=self.model_instance.model,
|
model=self.model_instance.model,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
message=AssistantPromptMessage(content=final_answer),
|
message=AssistantPromptMessage(content=final_answer),
|
||||||
usage=usage,
|
usage=usage,
|
||||||
system_fingerprint="",
|
system_fingerprint="",
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
PublishFrom.APPLICATION_MANAGER,
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from typing import Union, cast
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
|
||||||
from core.agent.output_tools import build_agent_output_tools
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
@ -37,7 +36,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
|||||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolInvokeFrom,
|
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
)
|
)
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
@ -253,14 +251,6 @@ class BaseAgentRunner(AppRunner):
|
|||||||
# save tool entity
|
# save tool entity
|
||||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||||
|
|
||||||
output_tools = build_agent_output_tools(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
invoke_from=self.application_generate_entity.invoke_from,
|
|
||||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
|
||||||
)
|
|
||||||
for tool in output_tools:
|
|
||||||
tool_instances[tool.entity.identity.name] = tool
|
|
||||||
|
|
||||||
return tool_instances, prompt_messages_tools
|
return tool_instances, prompt_messages_tools
|
||||||
|
|
||||||
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
|
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.agent.output_tools import FINAL_OUTPUT_TOOL
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||||
|
|
||||||
|
|
||||||
@ -42,9 +41,9 @@ class AgentScratchpadUnit(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
action_name: str
|
action_name: str
|
||||||
action_input: dict[str, Any] | str
|
action_input: Union[dict, str]
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Convert to dictionary.
|
Convert to dictionary.
|
||||||
"""
|
"""
|
||||||
@ -63,9 +62,9 @@ class AgentScratchpadUnit(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Check if the scratchpad unit is final.
|
Check if the scratchpad unit is final.
|
||||||
"""
|
"""
|
||||||
if self.action is None:
|
return self.action is None or (
|
||||||
return False
|
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
|
||||||
return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentEntity(BaseModel):
|
class AgentEntity(BaseModel):
|
||||||
@ -126,7 +125,7 @@ class ExecutionContext(BaseModel):
|
|||||||
"tenant_id": self.tenant_id,
|
"tenant_id": self.tenant_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_updates(self, **kwargs: Any) -> "ExecutionContext":
|
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||||
"""Create a new context with updated fields."""
|
"""Create a new context with updated fields."""
|
||||||
data = self.to_dict()
|
data = self.to_dict()
|
||||||
data.update(kwargs)
|
data.update(kwargs)
|
||||||
@ -179,23 +178,12 @@ class AgentLog(BaseModel):
|
|||||||
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
|
||||||
|
|
||||||
|
|
||||||
class AgentOutputKind(StrEnum):
|
|
||||||
"""
|
|
||||||
Agent output kind.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_TEXT = "output_text"
|
|
||||||
FINAL_OUTPUT_ANSWER = "final_output_answer"
|
|
||||||
FINAL_STRUCTURED_OUTPUT = "final_structured_output"
|
|
||||||
ILLEGAL_OUTPUT = "illegal_output"
|
|
||||||
|
|
||||||
|
|
||||||
class AgentResult(BaseModel):
|
class AgentResult(BaseModel):
|
||||||
"""
|
"""
|
||||||
Agent execution result.
|
Agent execution result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output: str | dict = Field(default="", description="The generated output")
|
text: str = Field(default="", description="The generated text")
|
||||||
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
|
||||||
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
usage: Any | None = Field(default=None, description="LLM usage statistics")
|
||||||
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
finish_reason: str | None = Field(default=None, description="Reason for completion")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Union, cast
|
from typing import Union
|
||||||
|
|
||||||
from core.agent.entities import AgentScratchpadUnit
|
from core.agent.entities import AgentScratchpadUnit
|
||||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||||
@ -10,52 +10,46 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
|||||||
class CotAgentOutputParser:
|
class CotAgentOutputParser:
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle_react_stream_output(
|
def handle_react_stream_output(
|
||||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||||
def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
|
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||||
action_name: str | None = None
|
action_name = None
|
||||||
action_input: Any | None = None
|
action_input = None
|
||||||
parsed_action: Any = action
|
if isinstance(action, str):
|
||||||
if isinstance(parsed_action, str):
|
|
||||||
try:
|
try:
|
||||||
parsed_action = json.loads(parsed_action, strict=False)
|
action = json.loads(action, strict=False)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return parsed_action or ""
|
return action or ""
|
||||||
|
|
||||||
# cohere always returns a list
|
# cohere always returns a list
|
||||||
if isinstance(parsed_action, list):
|
if isinstance(action, list) and len(action) == 1:
|
||||||
action_list: list[Any] = cast(list[Any], parsed_action)
|
action = action[0]
|
||||||
if len(action_list) == 1:
|
|
||||||
parsed_action = action_list[0]
|
|
||||||
|
|
||||||
if isinstance(parsed_action, dict):
|
for key, value in action.items():
|
||||||
action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
|
if "input" in key.lower():
|
||||||
for key, value in action_dict.items():
|
action_input = value
|
||||||
if "input" in key.lower():
|
else:
|
||||||
action_input = value
|
action_name = value
|
||||||
elif isinstance(value, str):
|
|
||||||
action_name = value
|
|
||||||
else:
|
|
||||||
return json.dumps(parsed_action)
|
|
||||||
|
|
||||||
if action_name is not None and action_input is not None:
|
if action_name is not None and action_input is not None:
|
||||||
return AgentScratchpadUnit.Action(
|
return AgentScratchpadUnit.Action(
|
||||||
action_name=action_name,
|
action_name=action_name,
|
||||||
action_input=action_input,
|
action_input=action_input,
|
||||||
)
|
)
|
||||||
return json.dumps(parsed_action)
|
else:
|
||||||
|
return json.dumps(action)
|
||||||
|
|
||||||
def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
|
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
|
||||||
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
|
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
json_blocks: list[dict[str, Any] | list[Any]] = []
|
json_blocks = []
|
||||||
for block in blocks:
|
for block in blocks:
|
||||||
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
||||||
json_blocks.append(json.loads(json_text, strict=False))
|
json_blocks.append(json.loads(json_text, strict=False))
|
||||||
return json_blocks
|
return json_blocks
|
||||||
except Exception:
|
except:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
code_block_cache = ""
|
code_block_cache = ""
|
||||||
|
|||||||
@ -1,108 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.tools.__base.tool import Tool
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
|
|
||||||
from core.tools.tool_manager import ToolManager
|
|
||||||
|
|
||||||
OUTPUT_TOOL_PROVIDER = "agent_output"
|
|
||||||
|
|
||||||
OUTPUT_TEXT_TOOL = "output_text"
|
|
||||||
FINAL_OUTPUT_TOOL = "final_output_answer"
|
|
||||||
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
|
|
||||||
ILLEGAL_OUTPUT_TOOL = "illegal_output"
|
|
||||||
|
|
||||||
OUTPUT_TOOL_NAMES: Sequence[str] = (
|
|
||||||
OUTPUT_TEXT_TOOL,
|
|
||||||
FINAL_OUTPUT_TOOL,
|
|
||||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
|
||||||
ILLEGAL_OUTPUT_TOOL,
|
|
||||||
)
|
|
||||||
|
|
||||||
TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
|
|
||||||
|
|
||||||
|
|
||||||
TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
|
|
||||||
|
|
||||||
|
|
||||||
def is_terminal_output_tool(tool_name: str) -> bool:
|
|
||||||
return tool_name in TERMINAL_OUTPUT_TOOL_NAMES
|
|
||||||
|
|
||||||
|
|
||||||
def get_terminal_tool_name(structured_output_enabled: bool) -> str:
|
|
||||||
return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
|
|
||||||
|
|
||||||
|
|
||||||
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
|
|
||||||
|
|
||||||
|
|
||||||
def build_agent_output_tools(
|
|
||||||
tenant_id: str,
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
tool_invoke_from: ToolInvokeFrom,
|
|
||||||
structured_output_schema: Mapping[str, Any] | None = None,
|
|
||||||
) -> list[Tool]:
|
|
||||||
def get_tool_runtime(_tool_name: str) -> Tool:
|
|
||||||
return ToolManager.get_tool_runtime(
|
|
||||||
provider_type=ToolProviderType.BUILT_IN,
|
|
||||||
provider_id=OUTPUT_TOOL_PROVIDER,
|
|
||||||
tool_name=_tool_name,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
tool_invoke_from=tool_invoke_from,
|
|
||||||
)
|
|
||||||
|
|
||||||
tools: list[Tool] = [
|
|
||||||
get_tool_runtime(OUTPUT_TEXT_TOOL),
|
|
||||||
get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
|
|
||||||
]
|
|
||||||
|
|
||||||
if structured_output_schema:
|
|
||||||
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
|
|
||||||
raw_tool.entity = raw_tool.entity.model_copy(deep=True)
|
|
||||||
data_parameter = ToolParameter(
|
|
||||||
name="data",
|
|
||||||
type=ToolParameter.ToolParameterType.OBJECT,
|
|
||||||
form=ToolParameter.ToolParameterForm.LLM,
|
|
||||||
required=True,
|
|
||||||
input_schema=dict(structured_output_schema),
|
|
||||||
label=I18nObject(en_US="__Data", zh_Hans="__Data"),
|
|
||||||
)
|
|
||||||
raw_tool.entity.parameters = [data_parameter]
|
|
||||||
|
|
||||||
def invoke_tool(
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> ToolInvokeMessage:
|
|
||||||
data = tool_parameters["data"]
|
|
||||||
if not data:
|
|
||||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
|
|
||||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
|
|
||||||
|
|
||||||
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
|
|
||||||
tools.append(raw_tool)
|
|
||||||
else:
|
|
||||||
raw_tool = get_tool_runtime(FINAL_OUTPUT_TOOL)
|
|
||||||
|
|
||||||
def invoke_tool(
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> ToolInvokeMessage:
|
|
||||||
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
|
|
||||||
|
|
||||||
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
|
|
||||||
tools.append(raw_tool)
|
|
||||||
|
|
||||||
return tools
|
|
||||||
@ -10,7 +10,6 @@ from collections.abc import Callable, Generator
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||||
from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
@ -57,7 +56,8 @@ class AgentPattern(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(
|
def run(
|
||||||
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
|
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
|
||||||
|
stream: bool = True,
|
||||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||||
"""Execute the agent strategy."""
|
"""Execute the agent strategy."""
|
||||||
pass
|
pass
|
||||||
@ -462,8 +462,6 @@ class AgentPattern(ABC):
|
|||||||
"""Convert tools to prompt message format."""
|
"""Convert tools to prompt message format."""
|
||||||
prompt_tools: list[PromptMessageTool] = []
|
prompt_tools: list[PromptMessageTool] = []
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL:
|
|
||||||
continue
|
|
||||||
prompt_tools.append(tool.to_prompt_message_tool())
|
prompt_tools.append(tool.to_prompt_message_tool())
|
||||||
return prompt_tools
|
return prompt_tools
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,10 @@
|
|||||||
"""Function Call strategy implementation."""
|
"""Function Call strategy implementation."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Literal, Protocol, Union, cast
|
from typing import Any, Union
|
||||||
|
|
||||||
from core.agent.entities import AgentLog, AgentResult
|
from core.agent.entities import AgentLog, AgentResult
|
||||||
from core.agent.output_tools import (
|
|
||||||
ILLEGAL_OUTPUT_TOOL,
|
|
||||||
TERMINAL_OUTPUT_MESSAGE,
|
|
||||||
)
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -30,7 +25,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
"""Function Call strategy using model's native tool calling capability."""
|
"""Function Call strategy using model's native tool calling capability."""
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
|
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
|
||||||
|
stream: bool = True,
|
||||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||||
"""Execute the function call agent strategy."""
|
"""Execute the function call agent strategy."""
|
||||||
# Convert tools to prompt format
|
# Convert tools to prompt format
|
||||||
@ -43,23 +39,9 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||||
final_text: str = ""
|
final_text: str = ""
|
||||||
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
|
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
output_files: list[File] = [] # Track files produced by tools
|
output_files: list[File] = [] # Track files produced by tools
|
||||||
|
|
||||||
class _LLMInvoker(Protocol):
|
|
||||||
def invoke_llm(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
prompt_messages: list[PromptMessage],
|
|
||||||
model_parameters: dict[str, Any],
|
|
||||||
tools: list[PromptMessageTool],
|
|
||||||
stop: list[str],
|
|
||||||
stream: Literal[False],
|
|
||||||
user: str | None,
|
|
||||||
callbacks: list[Any],
|
|
||||||
) -> LLMResult: ...
|
|
||||||
|
|
||||||
while function_call_state and iteration_step <= max_iterations:
|
while function_call_state and iteration_step <= max_iterations:
|
||||||
function_call_state = False
|
function_call_state = False
|
||||||
round_log = self._create_log(
|
round_log = self._create_log(
|
||||||
@ -69,7 +51,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
yield round_log
|
yield round_log
|
||||||
|
# On last iteration, remove tools to force final answer
|
||||||
|
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||||
model_log = self._create_log(
|
model_log = self._create_log(
|
||||||
label=f"{self.model_instance.model} Thought",
|
label=f"{self.model_instance.model} Thought",
|
||||||
log_type=AgentLog.LogType.THOUGHT,
|
log_type=AgentLog.LogType.THOUGHT,
|
||||||
@ -86,63 +69,47 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||||
|
|
||||||
# Invoke model
|
# Invoke model
|
||||||
invoker = cast(_LLMInvoker, self.model_instance)
|
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||||
chunks = invoker.invoke_llm(
|
|
||||||
prompt_messages=messages,
|
prompt_messages=messages,
|
||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
tools=prompt_tools,
|
tools=current_tools,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=False,
|
stream=stream,
|
||||||
user=self.context.user_id,
|
user=self.context.user_id,
|
||||||
callbacks=[],
|
callbacks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process response
|
# Process response
|
||||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||||
chunks, round_usage, model_log, emit_chunks=False
|
chunks, round_usage, model_log
|
||||||
)
|
)
|
||||||
|
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||||
if response_content:
|
|
||||||
replaced_tool_call = (
|
|
||||||
str(uuid.uuid4()),
|
|
||||||
ILLEGAL_OUTPUT_TOOL,
|
|
||||||
{
|
|
||||||
"raw": response_content,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
tool_calls.append(replaced_tool_call)
|
|
||||||
|
|
||||||
messages.append(self._create_assistant_message("", tool_calls))
|
|
||||||
|
|
||||||
# Accumulate to total usage
|
# Accumulate to total usage
|
||||||
round_usage_value = round_usage.get("usage")
|
round_usage_value = round_usage.get("usage")
|
||||||
if round_usage_value:
|
if round_usage_value:
|
||||||
self._accumulate_usage(total_usage, round_usage_value)
|
self._accumulate_usage(total_usage, round_usage_value)
|
||||||
|
|
||||||
|
# Update final text if no tool calls (this is likely the final answer)
|
||||||
|
if not tool_calls:
|
||||||
|
final_text = response_content
|
||||||
|
|
||||||
# Update finish reason
|
# Update finish reason
|
||||||
if chunk_finish_reason:
|
if chunk_finish_reason:
|
||||||
finish_reason = chunk_finish_reason
|
finish_reason = chunk_finish_reason
|
||||||
|
|
||||||
assert len(tool_calls) > 0
|
|
||||||
|
|
||||||
# Process tool calls
|
# Process tool calls
|
||||||
tool_outputs: dict[str, str] = {}
|
tool_outputs: dict[str, str] = {}
|
||||||
function_call_state = True
|
if tool_calls:
|
||||||
# Execute tools
|
function_call_state = True
|
||||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
# Execute tools
|
||||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||||
tool_name, tool_args, tool_call_id, messages, round_log
|
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||||
)
|
tool_name, tool_args, tool_call_id, messages, round_log
|
||||||
tool_entity = self._find_tool_by_name(tool_name)
|
)
|
||||||
if not tool_entity:
|
tool_outputs[tool_name] = tool_response
|
||||||
raise ValueError(f"Tool {tool_name} not found")
|
# Track files produced by tools
|
||||||
tool_outputs[tool_name] = tool_response
|
output_files.extend(tool_files)
|
||||||
# Track files produced by tools
|
|
||||||
output_files.extend(tool_files)
|
|
||||||
if tool_response == TERMINAL_OUTPUT_MESSAGE:
|
|
||||||
function_call_state = False
|
|
||||||
final_tool_args = tool_entity.transform_tool_parameters_type(tool_args)
|
|
||||||
|
|
||||||
yield self._finish_log(
|
yield self._finish_log(
|
||||||
round_log,
|
round_log,
|
||||||
data={
|
data={
|
||||||
@ -161,19 +128,8 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
# Return final result
|
# Return final result
|
||||||
from core.agent.entities import AgentResult
|
from core.agent.entities import AgentResult
|
||||||
|
|
||||||
output_payload: str | dict
|
|
||||||
output_text = final_tool_args.get("text")
|
|
||||||
output_structured_payload = final_tool_args.get("data")
|
|
||||||
|
|
||||||
if isinstance(output_structured_payload, dict):
|
|
||||||
output_payload = output_structured_payload
|
|
||||||
elif isinstance(output_text, str):
|
|
||||||
output_payload = output_text
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")
|
|
||||||
|
|
||||||
return AgentResult(
|
return AgentResult(
|
||||||
output=output_payload,
|
text=final_text,
|
||||||
files=output_files,
|
files=output_files,
|
||||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
@ -184,8 +140,6 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||||
llm_usage: dict[str, LLMUsage | None],
|
llm_usage: dict[str, LLMUsage | None],
|
||||||
start_log: AgentLog,
|
start_log: AgentLog,
|
||||||
*,
|
|
||||||
emit_chunks: bool,
|
|
||||||
) -> Generator[
|
) -> Generator[
|
||||||
LLMResultChunk | AgentLog,
|
LLMResultChunk | AgentLog,
|
||||||
None,
|
None,
|
||||||
@ -217,8 +171,7 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
if chunk.delta.finish_reason:
|
if chunk.delta.finish_reason:
|
||||||
finish_reason = chunk.delta.finish_reason
|
finish_reason = chunk.delta.finish_reason
|
||||||
|
|
||||||
if emit_chunks:
|
yield chunk
|
||||||
yield chunk
|
|
||||||
else:
|
else:
|
||||||
# Non-streaming response
|
# Non-streaming response
|
||||||
result: LLMResult = chunks
|
result: LLMResult = chunks
|
||||||
@ -233,12 +186,11 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
self._accumulate_usage(llm_usage, result.usage)
|
self._accumulate_usage(llm_usage, result.usage)
|
||||||
|
|
||||||
# Convert to streaming format
|
# Convert to streaming format
|
||||||
if emit_chunks:
|
yield LLMResultChunk(
|
||||||
yield LLMResultChunk(
|
model=result.model,
|
||||||
model=result.model,
|
prompt_messages=result.prompt_messages,
|
||||||
prompt_messages=result.prompt_messages,
|
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
)
|
||||||
)
|
|
||||||
yield self._finish_log(
|
yield self._finish_log(
|
||||||
start_log,
|
start_log,
|
||||||
data={
|
data={
|
||||||
@ -248,14 +200,6 @@ class FunctionCallStrategy(AgentPattern):
|
|||||||
)
|
)
|
||||||
return tool_calls, response_content, finish_reason
|
return tool_calls, response_content, finish_reason
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_output_text(value: Any) -> str:
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value
|
|
||||||
return json.dumps(value, ensure_ascii=False)
|
|
||||||
|
|
||||||
def _create_assistant_message(
|
def _create_assistant_message(
|
||||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||||
) -> AssistantPromptMessage:
|
) -> AssistantPromptMessage:
|
||||||
|
|||||||
@ -4,17 +4,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||||
from core.agent.output_tools import (
|
|
||||||
FINAL_OUTPUT_TOOL,
|
|
||||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
|
||||||
ILLEGAL_OUTPUT_TOOL,
|
|
||||||
OUTPUT_TEXT_TOOL,
|
|
||||||
OUTPUT_TOOL_NAME_SET,
|
|
||||||
)
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
@ -59,7 +52,8 @@ class ReActStrategy(AgentPattern):
|
|||||||
self.instruction = instruction
|
self.instruction = instruction
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
|
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
|
||||||
|
stream: bool = True,
|
||||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||||
"""Execute the ReAct agent strategy."""
|
"""Execute the ReAct agent strategy."""
|
||||||
# Initialize tracking
|
# Initialize tracking
|
||||||
@ -71,19 +65,6 @@ class ReActStrategy(AgentPattern):
|
|||||||
output_files: list[File] = [] # Track files produced by tools
|
output_files: list[File] = [] # Track files produced by tools
|
||||||
final_text: str = ""
|
final_text: str = ""
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
|
|
||||||
available_output_tool_names = {
|
|
||||||
tool_name
|
|
||||||
for tool_name in tool_instance_names
|
|
||||||
if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL
|
|
||||||
}
|
|
||||||
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
|
|
||||||
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
|
||||||
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
|
|
||||||
terminal_tool_name = FINAL_OUTPUT_TOOL
|
|
||||||
else:
|
|
||||||
raise ValueError("No terminal output tool configured")
|
|
||||||
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
|
|
||||||
|
|
||||||
# Add "Observation" to stop sequences
|
# Add "Observation" to stop sequences
|
||||||
if "Observation" not in stop:
|
if "Observation" not in stop:
|
||||||
@ -100,15 +81,10 @@ class ReActStrategy(AgentPattern):
|
|||||||
)
|
)
|
||||||
yield round_log
|
yield round_log
|
||||||
|
|
||||||
# Build prompt with tool restrictions on last iteration
|
# Build prompt with/without tools based on iteration
|
||||||
if iteration_step == max_iterations:
|
include_tools = iteration_step < max_iterations
|
||||||
tools_for_prompt = [
|
|
||||||
tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL]
|
|
||||||
current_messages = self._build_prompt_with_react_format(
|
current_messages = self._build_prompt_with_react_format(
|
||||||
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
|
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||||
)
|
)
|
||||||
|
|
||||||
model_log = self._create_log(
|
model_log = self._create_log(
|
||||||
@ -130,18 +106,18 @@ class ReActStrategy(AgentPattern):
|
|||||||
messages_to_use = current_messages
|
messages_to_use = current_messages
|
||||||
|
|
||||||
# Invoke model
|
# Invoke model
|
||||||
chunks = self.model_instance.invoke_llm(
|
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||||
prompt_messages=messages_to_use,
|
prompt_messages=messages_to_use,
|
||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stream=False,
|
stream=stream,
|
||||||
user=self.context.user_id or "",
|
user=self.context.user_id or "",
|
||||||
callbacks=[],
|
callbacks=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process response
|
# Process response
|
||||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||||
chunks, round_usage, model_log, current_messages, emit_chunks=False
|
chunks, round_usage, model_log, current_messages
|
||||||
)
|
)
|
||||||
agent_scratchpad.append(scratchpad)
|
agent_scratchpad.append(scratchpad)
|
||||||
|
|
||||||
@ -155,44 +131,28 @@ class ReActStrategy(AgentPattern):
|
|||||||
finish_reason = chunk_finish_reason
|
finish_reason = chunk_finish_reason
|
||||||
|
|
||||||
# Check if we have an action to execute
|
# Check if we have an action to execute
|
||||||
if scratchpad.action is None:
|
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||||
if not allow_illegal_output:
|
|
||||||
raise ValueError("Model did not call any tools")
|
|
||||||
illegal_action = AgentScratchpadUnit.Action(
|
|
||||||
action_name=ILLEGAL_OUTPUT_TOOL,
|
|
||||||
action_input={"raw": scratchpad.thought or ""},
|
|
||||||
)
|
|
||||||
scratchpad.action = illegal_action
|
|
||||||
scratchpad.action_str = illegal_action.model_dump_json()
|
|
||||||
react_state = True
|
react_state = True
|
||||||
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
|
# Execute tool
|
||||||
scratchpad.observation = observation
|
|
||||||
output_files.extend(tool_files)
|
|
||||||
else:
|
|
||||||
action_name = scratchpad.action.action_name
|
|
||||||
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
|
||||||
pass # output_text_payload = scratchpad.action.action_input.get("text")
|
|
||||||
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
|
|
||||||
data = scratchpad.action.action_input.get("data")
|
|
||||||
if isinstance(data, dict):
|
|
||||||
pass # structured_output_payload = data
|
|
||||||
elif action_name == FINAL_OUTPUT_TOOL:
|
|
||||||
if isinstance(scratchpad.action.action_input, dict):
|
|
||||||
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
|
|
||||||
else:
|
|
||||||
final_text = self._format_output_text(scratchpad.action.action_input)
|
|
||||||
|
|
||||||
observation, tool_files = yield from self._handle_tool_call(
|
observation, tool_files = yield from self._handle_tool_call(
|
||||||
scratchpad.action, current_messages, round_log
|
scratchpad.action, current_messages, round_log
|
||||||
)
|
)
|
||||||
scratchpad.observation = observation
|
scratchpad.observation = observation
|
||||||
|
# Track files produced by tools
|
||||||
output_files.extend(tool_files)
|
output_files.extend(tool_files)
|
||||||
|
|
||||||
if action_name == terminal_tool_name:
|
# Add observation to scratchpad for display
|
||||||
pass # terminal_output_seen = True
|
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||||
react_state = False
|
else:
|
||||||
else:
|
# Extract final answer
|
||||||
react_state = True
|
if scratchpad.action and scratchpad.action.action_input:
|
||||||
|
final_answer = scratchpad.action.action_input
|
||||||
|
if isinstance(final_answer, dict):
|
||||||
|
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||||
|
final_text = str(final_answer)
|
||||||
|
elif scratchpad.thought:
|
||||||
|
# If no action but we have thought, use thought as final answer
|
||||||
|
final_text = scratchpad.thought
|
||||||
|
|
||||||
yield self._finish_log(
|
yield self._finish_log(
|
||||||
round_log,
|
round_log,
|
||||||
@ -208,22 +168,17 @@ class ReActStrategy(AgentPattern):
|
|||||||
|
|
||||||
# Return final result
|
# Return final result
|
||||||
|
|
||||||
output_payload: str | dict
|
from core.agent.entities import AgentResult
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
|
||||||
return AgentResult(
|
return AgentResult(
|
||||||
output=output_payload,
|
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||||
files=output_files,
|
|
||||||
usage=total_usage.get("usage"),
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_prompt_with_react_format(
|
def _build_prompt_with_react_format(
|
||||||
self,
|
self,
|
||||||
original_messages: list[PromptMessage],
|
original_messages: list[PromptMessage],
|
||||||
agent_scratchpad: list[AgentScratchpadUnit],
|
agent_scratchpad: list[AgentScratchpadUnit],
|
||||||
tools: list[Tool] | None,
|
include_tools: bool = True,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""Build prompt messages with ReAct format."""
|
"""Build prompt messages with ReAct format."""
|
||||||
@ -240,13 +195,9 @@ class ReActStrategy(AgentPattern):
|
|||||||
# Format tools
|
# Format tools
|
||||||
tools_str = ""
|
tools_str = ""
|
||||||
tool_names = []
|
tool_names = []
|
||||||
if tools:
|
if include_tools and self.tools:
|
||||||
# Convert tools to prompt message tools format
|
# Convert tools to prompt message tools format
|
||||||
prompt_tools = [
|
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||||
tool.to_prompt_message_tool()
|
|
||||||
for tool in tools
|
|
||||||
if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL
|
|
||||||
]
|
|
||||||
tool_names = [tool.name for tool in prompt_tools]
|
tool_names = [tool.name for tool in prompt_tools]
|
||||||
|
|
||||||
# Format tools as JSON for comprehensive information
|
# Format tools as JSON for comprehensive information
|
||||||
@ -258,19 +209,12 @@ class ReActStrategy(AgentPattern):
|
|||||||
tools_str = "No tools available"
|
tools_str = "No tools available"
|
||||||
tool_names_str = ""
|
tool_names_str = ""
|
||||||
|
|
||||||
final_tool_name = FINAL_OUTPUT_TOOL
|
|
||||||
if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
|
|
||||||
final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
|
|
||||||
if final_tool_name not in tool_names:
|
|
||||||
raise ValueError("No terminal output tool available for prompt")
|
|
||||||
|
|
||||||
# Replace placeholders in the existing system prompt
|
# Replace placeholders in the existing system prompt
|
||||||
updated_content = msg.content
|
updated_content = msg.content
|
||||||
assert isinstance(updated_content, str)
|
assert isinstance(updated_content, str)
|
||||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||||
updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)
|
|
||||||
|
|
||||||
# Create new SystemPromptMessage with updated content
|
# Create new SystemPromptMessage with updated content
|
||||||
messages[i] = SystemPromptMessage(content=updated_content)
|
messages[i] = SystemPromptMessage(content=updated_content)
|
||||||
@ -302,12 +246,10 @@ class ReActStrategy(AgentPattern):
|
|||||||
|
|
||||||
def _handle_chunks(
|
def _handle_chunks(
|
||||||
self,
|
self,
|
||||||
chunks: LLMResult,
|
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||||
llm_usage: dict[str, Any],
|
llm_usage: dict[str, Any],
|
||||||
model_log: AgentLog,
|
model_log: AgentLog,
|
||||||
current_messages: list[PromptMessage],
|
current_messages: list[PromptMessage],
|
||||||
*,
|
|
||||||
emit_chunks: bool,
|
|
||||||
) -> Generator[
|
) -> Generator[
|
||||||
LLMResultChunk | AgentLog,
|
LLMResultChunk | AgentLog,
|
||||||
None,
|
None,
|
||||||
@ -319,20 +261,25 @@ class ReActStrategy(AgentPattern):
|
|||||||
"""
|
"""
|
||||||
usage_dict: dict[str, Any] = {}
|
usage_dict: dict[str, Any] = {}
|
||||||
|
|
||||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
# Convert non-streaming to streaming format if needed
|
||||||
yield LLMResultChunk(
|
if isinstance(chunks, LLMResult):
|
||||||
model=chunks.model,
|
# Create a generator from the LLMResult
|
||||||
prompt_messages=chunks.prompt_messages,
|
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||||
delta=LLMResultChunkDelta(
|
yield LLMResultChunk(
|
||||||
index=0,
|
model=chunks.model,
|
||||||
message=chunks.message,
|
prompt_messages=chunks.prompt_messages,
|
||||||
usage=chunks.usage,
|
delta=LLMResultChunkDelta(
|
||||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
index=0,
|
||||||
),
|
message=chunks.message,
|
||||||
system_fingerprint=chunks.system_fingerprint or "",
|
usage=chunks.usage,
|
||||||
)
|
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||||
|
),
|
||||||
|
system_fingerprint=chunks.system_fingerprint or "",
|
||||||
|
)
|
||||||
|
|
||||||
streaming_chunks = result_to_chunks()
|
streaming_chunks = result_to_chunks()
|
||||||
|
else:
|
||||||
|
streaming_chunks = chunks
|
||||||
|
|
||||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||||
|
|
||||||
@ -356,18 +303,14 @@ class ReActStrategy(AgentPattern):
|
|||||||
scratchpad.action_str = action_str
|
scratchpad.action_str = action_str
|
||||||
scratchpad.action = chunk
|
scratchpad.action = chunk
|
||||||
|
|
||||||
if emit_chunks:
|
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
else:
|
||||||
elif isinstance(chunk, str):
|
|
||||||
# Text chunk
|
# Text chunk
|
||||||
chunk_text = str(chunk)
|
chunk_text = str(chunk)
|
||||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||||
|
|
||||||
if emit_chunks:
|
yield self._create_text_chunk(chunk_text, current_messages)
|
||||||
yield self._create_text_chunk(chunk_text, current_messages)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
|
|
||||||
|
|
||||||
# Update usage
|
# Update usage
|
||||||
if usage_dict.get("usage"):
|
if usage_dict.get("usage"):
|
||||||
@ -391,14 +334,6 @@ class ReActStrategy(AgentPattern):
|
|||||||
|
|
||||||
return scratchpad, finish_reason
|
return scratchpad, finish_reason
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_output_text(value: Any) -> str:
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value
|
|
||||||
return json.dumps(value, ensure_ascii=False)
|
|
||||||
|
|
||||||
def _handle_tool_call(
|
def _handle_tool_call(
|
||||||
self,
|
self,
|
||||||
action: AgentScratchpadUnit.Action,
|
action: AgentScratchpadUnit.Action,
|
||||||
|
|||||||
@ -2,17 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from typing import TYPE_CHECKING
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
from core.agent.entities import AgentEntity, ExecutionContext
|
from core.agent.entities import AgentEntity, ExecutionContext
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
|
|
||||||
from ...app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from ...tools.entities.tool_entities import ToolInvokeFrom
|
|
||||||
from ..output_tools import build_agent_output_tools
|
|
||||||
from .base import AgentPattern, ToolInvokeHook
|
from .base import AgentPattern, ToolInvokeHook
|
||||||
from .function_call import FunctionCallStrategy
|
from .function_call import FunctionCallStrategy
|
||||||
from .react import ReActStrategy
|
from .react import ReActStrategy
|
||||||
@ -29,10 +25,6 @@ class StrategyFactory:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_strategy(
|
def create_strategy(
|
||||||
*,
|
|
||||||
tenant_id: str,
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
tool_invoke_from: ToolInvokeFrom,
|
|
||||||
model_features: list[ModelFeature],
|
model_features: list[ModelFeature],
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
context: ExecutionContext,
|
context: ExecutionContext,
|
||||||
@ -43,16 +35,11 @@ class StrategyFactory:
|
|||||||
agent_strategy: AgentEntity.Strategy | None = None,
|
agent_strategy: AgentEntity.Strategy | None = None,
|
||||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||||
instruction: str = "",
|
instruction: str = "",
|
||||||
structured_output_schema: Mapping[str, Any] | None = None,
|
|
||||||
) -> AgentPattern:
|
) -> AgentPattern:
|
||||||
"""
|
"""
|
||||||
Create an appropriate strategy based on model features.
|
Create an appropriate strategy based on model features.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id:
|
|
||||||
invoke_from:
|
|
||||||
tool_invoke_from:
|
|
||||||
structured_output_schema:
|
|
||||||
model_features: List of model features/capabilities
|
model_features: List of model features/capabilities
|
||||||
model_instance: Model instance to use
|
model_instance: Model instance to use
|
||||||
context: Execution context containing trace/audit information
|
context: Execution context containing trace/audit information
|
||||||
@ -67,14 +54,6 @@ class StrategyFactory:
|
|||||||
Returns:
|
Returns:
|
||||||
AgentStrategy instance
|
AgentStrategy instance
|
||||||
"""
|
"""
|
||||||
output_tools = build_agent_output_tools(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
tool_invoke_from=tool_invoke_from,
|
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
tools.extend(output_tools)
|
|
||||||
|
|
||||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ You have access to the following tools:
|
|||||||
{{tools}}
|
{{tools}}
|
||||||
|
|
||||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||||
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
|
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||||
|
|
||||||
Provide only ONE action per $JSON_BLOB, as shown:
|
Provide only ONE action per $JSON_BLOB, as shown:
|
||||||
|
|
||||||
@ -32,14 +32,12 @@ Thought: I know what to respond
|
|||||||
Action:
|
Action:
|
||||||
```
|
```
|
||||||
{
|
{
|
||||||
"action": "{{final_tool_name}}",
|
"action": "Final Answer",
|
||||||
"action_input": {
|
"action_input": "Final response to human"
|
||||||
"text": "Final response to human"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||||
{{historic_messages}}
|
{{historic_messages}}
|
||||||
Question: {{query}}
|
Question: {{query}}
|
||||||
{{agent_scratchpad}}
|
{{agent_scratchpad}}
|
||||||
@ -58,7 +56,7 @@ You have access to the following tools:
|
|||||||
{{tools}}
|
{{tools}}
|
||||||
|
|
||||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||||
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
|
Valid "action" values: "Final Answer" or {{tool_names}}
|
||||||
|
|
||||||
Provide only ONE action per $JSON_BLOB, as shown:
|
Provide only ONE action per $JSON_BLOB, as shown:
|
||||||
|
|
||||||
@ -83,14 +81,12 @@ Thought: I know what to respond
|
|||||||
Action:
|
Action:
|
||||||
```
|
```
|
||||||
{
|
{
|
||||||
"action": "{{final_tool_name}}",
|
"action": "Final Answer",
|
||||||
"action_input": {
|
"action_input": "Final response to human"
|
||||||
"text": "Final response to human"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
|
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -477,6 +477,7 @@ class LLMGenerator:
|
|||||||
prompt_messages=complete_messages,
|
prompt_messages=complete_messages,
|
||||||
output_model=CodeNodeStructuredOutput,
|
output_model=CodeNodeStructuredOutput,
|
||||||
model_parameters=model_parameters,
|
model_parameters=model_parameters,
|
||||||
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -557,14 +558,10 @@ class LLMGenerator:
|
|||||||
|
|
||||||
completion_params = model_config.get("completion_params", {}) if model_config else {}
|
completion_params = model_config.get("completion_params", {}) if model_config else {}
|
||||||
try:
|
try:
|
||||||
response = invoke_llm_with_pydantic_model(
|
response = invoke_llm_with_pydantic_model(provider=model_instance.provider, model_schema=model_schema,
|
||||||
provider=model_instance.provider,
|
model_instance=model_instance, prompt_messages=prompt_messages,
|
||||||
model_schema=model_schema,
|
output_model=SuggestedQuestionsOutput,
|
||||||
model_instance=model_instance,
|
model_parameters=completion_params, tenant_id=tenant_id)
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
output_model=SuggestedQuestionsOutput,
|
|
||||||
model_parameters=completion_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"questions": response.questions, "error": ""}
|
return {"questions": response.questions, "error": ""}
|
||||||
|
|
||||||
|
|||||||
@ -1,190 +1,188 @@
|
|||||||
from collections.abc import Callable, Mapping, Sequence
|
"""
|
||||||
from typing import Any, cast
|
File reference detection and conversion for structured output.
|
||||||
|
|
||||||
|
This module provides utilities to:
|
||||||
|
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
|
||||||
|
2. Convert file ID strings to File objects after LLM returns
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||||
|
from factories.file_factory import build_from_mapping
|
||||||
|
|
||||||
FILE_PATH_SCHEMA_TYPE = "file"
|
FILE_REF_FORMAT = "dify-file-ref"
|
||||||
FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
|
|
||||||
FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
|
|
||||||
|
|
||||||
|
|
||||||
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
|
def is_file_ref_property(schema: dict) -> bool:
|
||||||
if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
|
"""Check if a schema property is a file reference."""
|
||||||
return True
|
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
|
||||||
format_value = schema.get("format")
|
|
||||||
if not isinstance(format_value, str):
|
|
||||||
return False
|
|
||||||
normalized_format = format_value.lower().replace("_", "-")
|
|
||||||
return normalized_format in FILE_PATH_SCHEMA_FORMATS
|
|
||||||
|
|
||||||
|
|
||||||
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||||
file_path_fields: list[str] = []
|
"""
|
||||||
|
Recursively detect file reference fields in schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON Schema to analyze
|
||||||
|
path: Current path in the schema (used for recursion)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
|
||||||
|
"""
|
||||||
|
file_ref_paths: list[str] = []
|
||||||
schema_type = schema.get("type")
|
schema_type = schema.get("type")
|
||||||
|
|
||||||
if schema_type == "object":
|
if schema_type == "object":
|
||||||
properties = schema.get("properties")
|
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||||
if isinstance(properties, Mapping):
|
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||||
properties_mapping = cast(Mapping[str, Any], properties)
|
|
||||||
for prop_name, prop_schema in properties_mapping.items():
|
|
||||||
if not isinstance(prop_schema, Mapping):
|
|
||||||
continue
|
|
||||||
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
|
|
||||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
|
||||||
|
|
||||||
if is_file_path_property(prop_schema_mapping):
|
if is_file_ref_property(prop_schema):
|
||||||
file_path_fields.append(current_path)
|
file_ref_paths.append(current_path)
|
||||||
else:
|
elif isinstance(prop_schema, dict):
|
||||||
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
|
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
|
||||||
|
|
||||||
elif schema_type == "array":
|
elif schema_type == "array":
|
||||||
items_schema = schema.get("items")
|
items_schema = schema.get("items", {})
|
||||||
if not isinstance(items_schema, Mapping):
|
|
||||||
return file_path_fields
|
|
||||||
items_schema_mapping = cast(Mapping[str, Any], items_schema)
|
|
||||||
array_path = f"{path}[*]" if path else "[*]"
|
array_path = f"{path}[*]" if path else "[*]"
|
||||||
|
|
||||||
if is_file_path_property(items_schema_mapping):
|
if is_file_ref_property(items_schema):
|
||||||
file_path_fields.append(array_path)
|
file_ref_paths.append(array_path)
|
||||||
else:
|
elif isinstance(items_schema, dict):
|
||||||
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
|
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
|
||||||
|
|
||||||
return file_path_fields
|
return file_ref_paths
|
||||||
|
|
||||||
|
|
||||||
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
def convert_file_refs_in_output(
|
||||||
result = _deep_copy_value(schema)
|
|
||||||
if not isinstance(result, dict):
|
|
||||||
raise ValueError("structured_output_schema must be a JSON object")
|
|
||||||
result_dict = cast(dict[str, Any], result)
|
|
||||||
|
|
||||||
file_path_fields: list[str] = []
|
|
||||||
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
|
|
||||||
return result_dict, file_path_fields
|
|
||||||
|
|
||||||
|
|
||||||
def convert_sandbox_file_paths_in_output(
|
|
||||||
output: Mapping[str, Any],
|
output: Mapping[str, Any],
|
||||||
file_path_fields: Sequence[str],
|
json_schema: Mapping[str, Any],
|
||||||
file_resolver: Callable[[str], File],
|
tenant_id: str,
|
||||||
) -> tuple[dict[str, Any], list[File]]:
|
) -> dict[str, Any]:
|
||||||
if not file_path_fields:
|
"""
|
||||||
return dict(output), []
|
Convert file ID strings to File objects based on schema.
|
||||||
|
|
||||||
result = _deep_copy_value(output)
|
Args:
|
||||||
if not isinstance(result, dict):
|
output: The structured_output from LLM result
|
||||||
raise ValueError("Structured output must be a JSON object")
|
json_schema: The original JSON schema (to detect file ref fields)
|
||||||
result_dict = cast(dict[str, Any], result)
|
tenant_id: Tenant ID for file lookup
|
||||||
|
|
||||||
files: list[File] = []
|
Returns:
|
||||||
for path in file_path_fields:
|
Output with file references converted to File objects
|
||||||
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
|
"""
|
||||||
|
file_ref_paths = detect_file_ref_fields(json_schema)
|
||||||
|
if not file_ref_paths:
|
||||||
|
return dict(output)
|
||||||
|
|
||||||
return result_dict, files
|
result = _deep_copy_dict(output)
|
||||||
|
|
||||||
|
for path in file_ref_paths:
|
||||||
|
_convert_path_in_place(result, path.split("."), tenant_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
|
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
|
||||||
schema_type = schema.get("type")
|
"""Deep copy a mapping to a mutable dict."""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
if schema_type == "object":
|
for key, value in obj.items():
|
||||||
properties = schema.get("properties")
|
if isinstance(value, Mapping):
|
||||||
if isinstance(properties, Mapping):
|
result[key] = _deep_copy_dict(value)
|
||||||
properties_mapping = cast(Mapping[str, Any], properties)
|
elif isinstance(value, list):
|
||||||
for prop_name, prop_schema in properties_mapping.items():
|
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
|
||||||
if not isinstance(prop_schema, dict):
|
|
||||||
continue
|
|
||||||
prop_schema_dict = cast(dict[str, Any], prop_schema)
|
|
||||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
|
||||||
|
|
||||||
if is_file_path_property(prop_schema_dict):
|
|
||||||
_normalize_file_path_schema(prop_schema_dict)
|
|
||||||
file_path_fields.append(current_path)
|
|
||||||
else:
|
|
||||||
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
|
|
||||||
|
|
||||||
elif schema_type == "array":
|
|
||||||
items_schema = schema.get("items")
|
|
||||||
if not isinstance(items_schema, dict):
|
|
||||||
return
|
|
||||||
items_schema_dict = cast(dict[str, Any], items_schema)
|
|
||||||
array_path = f"{path}[*]" if path else "[*]"
|
|
||||||
|
|
||||||
if is_file_path_property(items_schema_dict):
|
|
||||||
_normalize_file_path_schema(items_schema_dict)
|
|
||||||
file_path_fields.append(array_path)
|
|
||||||
else:
|
else:
|
||||||
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
|
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
|
||||||
schema["type"] = "string"
|
"""Convert file refs at the given path in place, wrapping in Segment types."""
|
||||||
schema.pop("format", None)
|
|
||||||
description = schema.get("description", "")
|
|
||||||
if description:
|
|
||||||
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
|
|
||||||
else:
|
|
||||||
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
|
|
||||||
|
|
||||||
|
|
||||||
def _deep_copy_value(value: Any) -> Any:
|
|
||||||
if isinstance(value, Mapping):
|
|
||||||
mapping = cast(Mapping[str, Any], value)
|
|
||||||
return {key: _deep_copy_value(item) for key, item in mapping.items()}
|
|
||||||
if isinstance(value, list):
|
|
||||||
list_value = cast(list[Any], value)
|
|
||||||
return [_deep_copy_value(item) for item in list_value]
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_path_in_place(
|
|
||||||
obj: dict[str, Any],
|
|
||||||
path_parts: list[str],
|
|
||||||
file_resolver: Callable[[str], File],
|
|
||||||
files: list[File],
|
|
||||||
) -> None:
|
|
||||||
if not path_parts:
|
if not path_parts:
|
||||||
return
|
return
|
||||||
|
|
||||||
current = path_parts[0]
|
current = path_parts[0]
|
||||||
remaining = path_parts[1:]
|
remaining = path_parts[1:]
|
||||||
|
|
||||||
|
# Handle array notation like "files[*]"
|
||||||
if current.endswith("[*]"):
|
if current.endswith("[*]"):
|
||||||
key = current[:-3] if current != "[*]" else ""
|
key = current[:-3] if current != "[*]" else None
|
||||||
target_value = obj.get(key) if key else obj
|
target = obj.get(key) if key else obj
|
||||||
|
|
||||||
if isinstance(target_value, list):
|
if isinstance(target, list):
|
||||||
target_list = cast(list[Any], target_value)
|
|
||||||
if remaining:
|
if remaining:
|
||||||
for item in target_list:
|
# Nested array with remaining path - recurse into each item
|
||||||
|
for item in target:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
item_dict = cast(dict[str, Any], item)
|
_convert_path_in_place(item, remaining, tenant_id)
|
||||||
_convert_path_in_place(item_dict, remaining, file_resolver, files)
|
|
||||||
else:
|
else:
|
||||||
resolved_files: list[File] = []
|
# Array of file IDs - convert all and wrap in ArrayFileSegment
|
||||||
for item in target_list:
|
files: list[File] = []
|
||||||
if not isinstance(item, str):
|
for item in target:
|
||||||
raise ValueError("File path must be a string")
|
file = _convert_file_id(item, tenant_id)
|
||||||
file = file_resolver(item)
|
if file is not None:
|
||||||
files.append(file)
|
files.append(file)
|
||||||
resolved_files.append(file)
|
# Replace the array with ArrayFileSegment
|
||||||
if key:
|
if key:
|
||||||
obj[key] = ArrayFileSegment(value=resolved_files)
|
obj[key] = ArrayFileSegment(value=files)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not remaining:
|
if not remaining:
|
||||||
if current not in obj:
|
# Leaf node - convert the value and wrap in FileSegment
|
||||||
return
|
if current in obj:
|
||||||
value = obj[current]
|
file = _convert_file_id(obj[current], tenant_id)
|
||||||
if value is None:
|
if file is not None:
|
||||||
obj[current] = None
|
obj[current] = FileSegment(value=file)
|
||||||
return
|
else:
|
||||||
if not isinstance(value, str):
|
obj[current] = None
|
||||||
raise ValueError("File path must be a string")
|
else:
|
||||||
file = file_resolver(value)
|
# Recurse into nested object
|
||||||
files.append(file)
|
if current in obj and isinstance(obj[current], dict):
|
||||||
obj[current] = FileSegment(value=file)
|
_convert_path_in_place(obj[current], remaining, tenant_id)
|
||||||
return
|
|
||||||
|
|
||||||
if current in obj and isinstance(obj[current], dict):
|
|
||||||
_convert_path_in_place(obj[current], remaining, file_resolver, files)
|
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
|
||||||
|
"""
|
||||||
|
Convert a file ID string to a File object.
|
||||||
|
|
||||||
|
Tries multiple file sources in order:
|
||||||
|
1. ToolFile (files generated by tools/workflows)
|
||||||
|
2. UploadFile (files uploaded by users)
|
||||||
|
"""
|
||||||
|
if not isinstance(file_id, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate UUID format
|
||||||
|
try:
|
||||||
|
uuid.UUID(file_id)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try ToolFile first (files generated by tools/workflows)
|
||||||
|
try:
|
||||||
|
return build_from_mapping(
|
||||||
|
mapping={
|
||||||
|
"transfer_method": "tool_file",
|
||||||
|
"tool_file_id": file_id,
|
||||||
|
},
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try UploadFile (files uploaded by users)
|
||||||
|
try:
|
||||||
|
return build_from_mapping(
|
||||||
|
mapping={
|
||||||
|
"transfer_method": "local_file",
|
||||||
|
"upload_file_id": file_id,
|
||||||
|
},
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# File not found in any source
|
||||||
|
return None
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import json_repair
|
|||||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||||
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.file_ref import detect_file_path_fields
|
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
@ -55,11 +55,12 @@ def invoke_llm_with_structured_output(
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: Sequence[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
json_schema: Mapping[str, Any],
|
json_schema: Mapping[str, Any],
|
||||||
model_parameters: Mapping[str, Any] | None = None,
|
model_parameters: Mapping | None = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
tenant_id: str | None = None,
|
||||||
) -> LLMResultWithStructuredOutput:
|
) -> LLMResultWithStructuredOutput:
|
||||||
"""
|
"""
|
||||||
Invoke large language model with structured output.
|
Invoke large language model with structured output.
|
||||||
@ -77,12 +78,14 @@ def invoke_llm_with_structured_output(
|
|||||||
:param stop: stop words
|
:param stop: stop words
|
||||||
:param user: unique user id
|
:param user: unique user id
|
||||||
:param callbacks: callbacks
|
:param callbacks: callbacks
|
||||||
:return: response with structured output
|
:param tenant_id: tenant ID for file reference conversion. When provided and
|
||||||
|
json_schema contains file reference fields (format: "dify-file-ref"),
|
||||||
|
file IDs in the output will be automatically converted to File objects.
|
||||||
|
:return: full response or stream response chunk generator result
|
||||||
"""
|
"""
|
||||||
model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
|
model_parameters_with_json_schema: dict[str, Any] = {
|
||||||
|
**(model_parameters or {}),
|
||||||
if detect_file_path_fields(json_schema):
|
}
|
||||||
raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
|
|
||||||
|
|
||||||
# Determine structured output strategy
|
# Determine structured output strategy
|
||||||
|
|
||||||
@ -119,6 +122,14 @@ def invoke_llm_with_structured_output(
|
|||||||
# Fill missing fields with default values
|
# Fill missing fields with default values
|
||||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||||
|
|
||||||
|
# Convert file references if tenant_id is provided
|
||||||
|
if tenant_id is not None:
|
||||||
|
structured_output = convert_file_refs_in_output(
|
||||||
|
output=structured_output,
|
||||||
|
json_schema=json_schema,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
return LLMResultWithStructuredOutput(
|
return LLMResultWithStructuredOutput(
|
||||||
structured_output=structured_output,
|
structured_output=structured_output,
|
||||||
model=llm_result.model,
|
model=llm_result.model,
|
||||||
@ -136,11 +147,12 @@ def invoke_llm_with_pydantic_model(
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: Sequence[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
output_model: type[T],
|
output_model: type[T],
|
||||||
model_parameters: Mapping[str, Any] | None = None,
|
model_parameters: Mapping | None = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
callbacks: list[Callback] | None = None,
|
callbacks: list[Callback] | None = None,
|
||||||
|
tenant_id: str | None = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Invoke large language model with a Pydantic output model.
|
Invoke large language model with a Pydantic output model.
|
||||||
@ -148,8 +160,11 @@ def invoke_llm_with_pydantic_model(
|
|||||||
This helper generates a JSON schema from the Pydantic model, invokes the
|
This helper generates a JSON schema from the Pydantic model, invokes the
|
||||||
structured-output LLM path, and validates the result.
|
structured-output LLM path, and validates the result.
|
||||||
|
|
||||||
The helper performs a non-streaming invocation and returns the validated
|
The stream parameter controls the underlying LLM invocation mode:
|
||||||
Pydantic model directly.
|
- stream=True (default): Uses streaming LLM call, consumes the generator internally
|
||||||
|
- stream=False: Uses non-streaming LLM call
|
||||||
|
|
||||||
|
In both cases, the function returns the validated Pydantic model directly.
|
||||||
"""
|
"""
|
||||||
json_schema = _schema_from_pydantic(output_model)
|
json_schema = _schema_from_pydantic(output_model)
|
||||||
|
|
||||||
@ -164,6 +179,7 @@ def invoke_llm_with_pydantic_model(
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
user=user,
|
user=user,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
structured_output = result.structured_output
|
structured_output = result.structured_output
|
||||||
@ -220,27 +236,25 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
|
|||||||
return _parse_structured_output(content)
|
return _parse_structured_output(content)
|
||||||
|
|
||||||
|
|
||||||
def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
|
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
|
||||||
"""Parse JSON from tool call arguments."""
|
"""Parse JSON from tool call arguments."""
|
||||||
if not arguments:
|
if not arguments:
|
||||||
raise OutputParserError("Tool call arguments is empty")
|
raise OutputParserError("Tool call arguments is empty")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_any = json.loads(arguments)
|
parsed = json.loads(arguments)
|
||||||
if not isinstance(parsed_any, dict):
|
if not isinstance(parsed, dict):
|
||||||
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
|
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
|
||||||
parsed = cast(dict[str, Any], parsed_any)
|
|
||||||
return parsed
|
return parsed
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Try to repair malformed JSON
|
# Try to repair malformed JSON
|
||||||
repaired_any = json_repair.loads(arguments)
|
repaired = json_repair.loads(arguments)
|
||||||
if not isinstance(repaired_any, dict):
|
if not isinstance(repaired, dict):
|
||||||
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
|
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
|
||||||
repaired: dict[str, Any] = repaired_any
|
|
||||||
return repaired
|
return repaired
|
||||||
|
|
||||||
|
|
||||||
def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
||||||
"""Get default empty value for a JSON schema type."""
|
"""Get default empty value for a JSON schema type."""
|
||||||
# Handle array of types (e.g., ["string", "null"])
|
# Handle array of types (e.g., ["string", "null"])
|
||||||
if isinstance(type_name, list):
|
if isinstance(type_name, list):
|
||||||
@ -297,7 +311,7 @@ def fill_defaults_from_schema(
|
|||||||
# Create empty object and recursively fill its required fields
|
# Create empty object and recursively fill its required fields
|
||||||
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
|
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
|
||||||
else:
|
else:
|
||||||
result[prop_name] = get_default_value_for_type(prop_type)
|
result[prop_name] = _get_default_value_for_type(prop_type)
|
||||||
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
|
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
|
||||||
# Field exists and is an object, recursively fill nested required fields
|
# Field exists and is an object, recursively fill nested required fields
|
||||||
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
|
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
|
||||||
@ -308,10 +322,10 @@ def fill_defaults_from_schema(
|
|||||||
def _handle_native_json_schema(
|
def _handle_native_json_schema(
|
||||||
provider: str,
|
provider: str,
|
||||||
model_schema: AIModelEntity,
|
model_schema: AIModelEntity,
|
||||||
structured_output_schema: Mapping[str, Any],
|
structured_output_schema: Mapping,
|
||||||
model_parameters: dict[str, Any],
|
model_parameters: dict,
|
||||||
rules: list[ParameterRule],
|
rules: list[ParameterRule],
|
||||||
) -> dict[str, Any]:
|
):
|
||||||
"""
|
"""
|
||||||
Handle structured output for models with native JSON schema support.
|
Handle structured output for models with native JSON schema support.
|
||||||
|
|
||||||
@ -333,7 +347,7 @@ def _handle_native_json_schema(
|
|||||||
return model_parameters
|
return model_parameters
|
||||||
|
|
||||||
|
|
||||||
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
|
def _set_response_format(model_parameters: dict, rules: list):
|
||||||
"""
|
"""
|
||||||
Set the appropriate response format parameter based on model rules.
|
Set the appropriate response format parameter based on model rules.
|
||||||
|
|
||||||
@ -349,7 +363,7 @@ def _set_response_format(model_parameters: dict[str, Any], rules: list[Parameter
|
|||||||
|
|
||||||
|
|
||||||
def _handle_prompt_based_schema(
|
def _handle_prompt_based_schema(
|
||||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
|
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Handle structured output for models without native JSON schema support.
|
Handle structured output for models without native JSON schema support.
|
||||||
@ -386,27 +400,28 @@ def _handle_prompt_based_schema(
|
|||||||
return updated_prompt
|
return updated_prompt
|
||||||
|
|
||||||
|
|
||||||
def _parse_structured_output(result_text: str) -> dict[str, Any]:
|
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||||
|
structured_output: Mapping[str, Any] = {}
|
||||||
|
parsed: Mapping[str, Any] = {}
|
||||||
try:
|
try:
|
||||||
parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
|
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||||
return parsed
|
if not isinstance(parsed, dict):
|
||||||
|
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = parsed
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
# if the result_text is not a valid json, try to repair it
|
# if the result_text is not a valid json, try to repair it
|
||||||
temp_parsed: Any = json_repair.loads(result_text)
|
temp_parsed = json_repair.loads(result_text)
|
||||||
if isinstance(temp_parsed, list):
|
|
||||||
temp_parsed_list = cast(list[Any], temp_parsed)
|
|
||||||
dict_items: list[dict[str, Any]] = []
|
|
||||||
for item in temp_parsed_list:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
dict_items.append(cast(dict[str, Any], item))
|
|
||||||
temp_parsed = dict_items[0] if dict_items else {}
|
|
||||||
if not isinstance(temp_parsed, dict):
|
if not isinstance(temp_parsed, dict):
|
||||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||||
temp_parsed_dict = cast(dict[str, Any], temp_parsed)
|
if isinstance(temp_parsed, list):
|
||||||
return temp_parsed_dict
|
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||||
|
else:
|
||||||
|
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||||
|
structured_output = cast(dict, temp_parsed)
|
||||||
|
return structured_output
|
||||||
|
|
||||||
|
|
||||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
|
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
|
||||||
"""
|
"""
|
||||||
Prepare JSON schema based on model requirements.
|
Prepare JSON schema based on model requirements.
|
||||||
|
|
||||||
@ -418,49 +433,54 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Deep copy to avoid modifying the original schema
|
# Deep copy to avoid modifying the original schema
|
||||||
processed_schema = deepcopy(schema)
|
processed_schema = dict(deepcopy(schema))
|
||||||
processed_schema_dict = dict(processed_schema)
|
|
||||||
|
|
||||||
# Convert boolean types to string types (common requirement)
|
# Convert boolean types to string types (common requirement)
|
||||||
convert_boolean_to_string(processed_schema_dict)
|
convert_boolean_to_string(processed_schema)
|
||||||
|
|
||||||
# Apply model-specific transformations
|
# Apply model-specific transformations
|
||||||
if SpecialModelType.GEMINI in model_schema.model:
|
if SpecialModelType.GEMINI in model_schema.model:
|
||||||
remove_additional_properties(processed_schema_dict)
|
remove_additional_properties(processed_schema)
|
||||||
return processed_schema_dict
|
return processed_schema
|
||||||
if SpecialModelType.OLLAMA in provider:
|
elif SpecialModelType.OLLAMA in provider:
|
||||||
return processed_schema_dict
|
return processed_schema
|
||||||
|
else:
|
||||||
# Default format with name field
|
# Default format with name field
|
||||||
return {"schema": processed_schema_dict, "name": "llm_response"}
|
return {"schema": processed_schema, "name": "llm_response"}
|
||||||
|
|
||||||
|
|
||||||
def remove_additional_properties(schema: dict[str, Any]) -> None:
|
def remove_additional_properties(schema: dict):
|
||||||
"""
|
"""
|
||||||
Remove additionalProperties fields from JSON schema.
|
Remove additionalProperties fields from JSON schema.
|
||||||
Used for models like Gemini that don't support this property.
|
Used for models like Gemini that don't support this property.
|
||||||
|
|
||||||
:param schema: JSON schema to modify in-place
|
:param schema: JSON schema to modify in-place
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
# Remove additionalProperties at current level
|
# Remove additionalProperties at current level
|
||||||
schema.pop("additionalProperties", None)
|
schema.pop("additionalProperties", None)
|
||||||
|
|
||||||
# Process nested structures recursively
|
# Process nested structures recursively
|
||||||
for value in schema.values():
|
for value in schema.values():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
remove_additional_properties(cast(dict[str, Any], value))
|
remove_additional_properties(value)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
for item in cast(list[Any], value):
|
for item in value:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
remove_additional_properties(cast(dict[str, Any], item))
|
remove_additional_properties(item)
|
||||||
|
|
||||||
|
|
||||||
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
def convert_boolean_to_string(schema: dict):
|
||||||
"""
|
"""
|
||||||
Convert boolean type specifications to string in JSON schema.
|
Convert boolean type specifications to string in JSON schema.
|
||||||
|
|
||||||
:param schema: JSON schema to modify in-place
|
:param schema: JSON schema to modify in-place
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(schema, dict):
|
||||||
|
return
|
||||||
|
|
||||||
# Check for boolean type at current level
|
# Check for boolean type at current level
|
||||||
if schema.get("type") == "boolean":
|
if schema.get("type") == "boolean":
|
||||||
schema["type"] = "string"
|
schema["type"] = "string"
|
||||||
@ -468,8 +488,8 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
|||||||
# Process nested dictionaries and lists recursively
|
# Process nested dictionaries and lists recursively
|
||||||
for value in schema.values():
|
for value in schema.values():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
convert_boolean_to_string(cast(dict[str, Any], value))
|
convert_boolean_to_string(value)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
for item in cast(list[Any], value):
|
for item in value:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
convert_boolean_to_string(cast(dict[str, Any], item))
|
convert_boolean_to_string(item)
|
||||||
|
|||||||
@ -1,13 +1,11 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TypeVar, Union
|
||||||
|
|
||||||
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
|
||||||
pass
|
|
||||||
|
|
||||||
MessageType = TypeVar("MessageType", bound=ToolInvokeMessage)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -89,7 +87,7 @@ def merge_blob_chunks(
|
|||||||
),
|
),
|
||||||
meta=resp.meta,
|
meta=resp.meta,
|
||||||
)
|
)
|
||||||
assert isinstance(merged_message, ToolInvokeMessage)
|
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
|
||||||
yield merged_message # type: ignore
|
yield merged_message # type: ignore
|
||||||
# Clean up the buffer
|
# Clean up the buffer
|
||||||
del files[chunk_id]
|
del files[chunk_id]
|
||||||
|
|||||||
@ -14,8 +14,7 @@ from core.skill.entities import ToolAccessPolicy
|
|||||||
from core.skill.entities.tool_dependencies import ToolDependencies
|
from core.skill.entities.tool_dependencies import ToolDependencies
|
||||||
from core.tools.signature import sign_tool_file
|
from core.tools.signature import sign_tool_file
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from core.virtual_environment.__base.exec import CommandExecutionError
|
from core.virtual_environment.__base.helpers import pipeline
|
||||||
from core.virtual_environment.__base.helpers import execute, pipeline
|
|
||||||
|
|
||||||
from ..bash.dify_cli import DifyCliConfig
|
from ..bash.dify_cli import DifyCliConfig
|
||||||
from ..entities import DifyCli
|
from ..entities import DifyCli
|
||||||
@ -120,6 +119,21 @@ class SandboxBashSession:
|
|||||||
return self._bash_tool
|
return self._bash_tool
|
||||||
|
|
||||||
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
|
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
|
||||||
|
"""
|
||||||
|
Collect files from sandbox output directory and save them as ToolFiles.
|
||||||
|
|
||||||
|
Scans the specified output directory in sandbox, downloads each file,
|
||||||
|
saves it as a ToolFile, and returns a list of File objects. The File
|
||||||
|
objects will have valid tool_file_id that can be referenced by subsequent
|
||||||
|
nodes via structured output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: Directory path in sandbox to scan for output files.
|
||||||
|
Defaults to "output" (relative to workspace).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of File objects representing the collected files.
|
||||||
|
"""
|
||||||
vm = self._sandbox.vm
|
vm = self._sandbox.vm
|
||||||
collected_files: list[File] = []
|
collected_files: list[File] = []
|
||||||
|
|
||||||
@ -130,6 +144,8 @@ class SandboxBashSession:
|
|||||||
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
|
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
|
||||||
return collected_files
|
return collected_files
|
||||||
|
|
||||||
|
tool_file_manager = ToolFileManager()
|
||||||
|
|
||||||
for file_state in file_states:
|
for file_state in file_states:
|
||||||
# Skip files that are too large
|
# Skip files that are too large
|
||||||
if file_state.size > MAX_OUTPUT_FILE_SIZE:
|
if file_state.size > MAX_OUTPUT_FILE_SIZE:
|
||||||
@ -146,14 +162,47 @@ class SandboxBashSession:
|
|||||||
file_content = vm.download_file(file_state.path)
|
file_content = vm.download_file(file_state.path)
|
||||||
file_binary = file_content.getvalue()
|
file_binary = file_content.getvalue()
|
||||||
|
|
||||||
|
# Determine mime type from extension
|
||||||
filename = os.path.basename(file_state.path)
|
filename = os.path.basename(file_state.path)
|
||||||
file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "application/octet-stream"
|
||||||
|
|
||||||
|
# Save as ToolFile
|
||||||
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
|
user_id=self._user_id,
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=file_binary,
|
||||||
|
mimetype=mime_type,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine file type from mime type
|
||||||
|
file_type = _get_file_type_from_mime(mime_type)
|
||||||
|
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
|
||||||
|
url = sign_tool_file(tool_file.id, extension)
|
||||||
|
|
||||||
|
# Create File object with tool_file_id as related_id
|
||||||
|
file_obj = File(
|
||||||
|
id=tool_file.id, # Use tool_file_id as the File id for easy reference
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
type=file_type,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
filename=filename,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size=len(file_binary),
|
||||||
|
related_id=tool_file.id,
|
||||||
|
url=url,
|
||||||
|
storage_key=tool_file.file_key,
|
||||||
|
)
|
||||||
collected_files.append(file_obj)
|
collected_files.append(file_obj)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Collected sandbox output file: %s -> tool_file_id=%s",
|
"Collected sandbox output file: %s -> tool_file_id=%s",
|
||||||
file_state.path,
|
file_state.path,
|
||||||
file_obj.id,
|
tool_file.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@ -167,85 +216,6 @@ class SandboxBashSession:
|
|||||||
)
|
)
|
||||||
return collected_files
|
return collected_files
|
||||||
|
|
||||||
def download_file(self, path: str) -> File:
|
|
||||||
path_kind = self._detect_path_kind(path)
|
|
||||||
if path_kind == "dir":
|
|
||||||
raise ValueError("Directory outputs are not supported")
|
|
||||||
if path_kind != "file":
|
|
||||||
raise ValueError(f"Sandbox file not found: {path}")
|
|
||||||
|
|
||||||
file_content = self._sandbox.vm.download_file(path)
|
|
||||||
file_binary = file_content.getvalue()
|
|
||||||
if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
|
|
||||||
raise ValueError(f"Sandbox file exceeds size limit: {path}")
|
|
||||||
|
|
||||||
filename = os.path.basename(path) or "file"
|
|
||||||
return self._create_tool_file(filename=filename, file_binary=file_binary)
|
|
||||||
|
|
||||||
def _detect_path_kind(self, path: str) -> str:
|
|
||||||
script = r"""
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
p = sys.argv[1]
|
|
||||||
if os.path.isdir(p):
|
|
||||||
print("dir")
|
|
||||||
raise SystemExit(0)
|
|
||||||
if os.path.isfile(p):
|
|
||||||
print("file")
|
|
||||||
raise SystemExit(0)
|
|
||||||
print("none")
|
|
||||||
raise SystemExit(2)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = execute(
|
|
||||||
self._sandbox.vm,
|
|
||||||
[
|
|
||||||
"sh",
|
|
||||||
"-c",
|
|
||||||
'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"',
|
|
||||||
script,
|
|
||||||
path,
|
|
||||||
],
|
|
||||||
timeout=10,
|
|
||||||
error_message="Failed to inspect sandbox path",
|
|
||||||
)
|
|
||||||
except CommandExecutionError as exc:
|
|
||||||
raise ValueError(str(exc)) from exc
|
|
||||||
return result.stdout.decode("utf-8", errors="replace").strip()
|
|
||||||
|
|
||||||
def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
|
|
||||||
mime_type, _ = mimetypes.guess_type(filename)
|
|
||||||
if not mime_type:
|
|
||||||
mime_type = "application/octet-stream"
|
|
||||||
|
|
||||||
tool_file = ToolFileManager().create_file_by_raw(
|
|
||||||
user_id=self._user_id,
|
|
||||||
tenant_id=self._tenant_id,
|
|
||||||
conversation_id=None,
|
|
||||||
file_binary=file_binary,
|
|
||||||
mimetype=mime_type,
|
|
||||||
filename=filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
file_type = _get_file_type_from_mime(mime_type)
|
|
||||||
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
|
|
||||||
url = sign_tool_file(tool_file.id, extension)
|
|
||||||
|
|
||||||
return File(
|
|
||||||
id=tool_file.id,
|
|
||||||
tenant_id=self._tenant_id,
|
|
||||||
type=file_type,
|
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
|
||||||
filename=filename,
|
|
||||||
extension=extension,
|
|
||||||
mime_type=mime_type,
|
|
||||||
size=len(file_binary),
|
|
||||||
related_id=tool_file.id,
|
|
||||||
url=url,
|
|
||||||
storage_key=tool_file.file_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_file_type_from_mime(mime_type: str) -> FileType:
|
def _get_file_type_from_mime(mime_type: str) -> FileType:
|
||||||
"""Determine FileType from mime type."""
|
"""Determine FileType from mime type."""
|
||||||
|
|||||||
@ -57,7 +57,7 @@ class Tool(ABC):
|
|||||||
tool_parameters.update(self.runtime.runtime_parameters)
|
tool_parameters.update(self.runtime.runtime_parameters)
|
||||||
|
|
||||||
# try parse tool parameters into the correct type
|
# try parse tool parameters into the correct type
|
||||||
tool_parameters = self.transform_tool_parameters_type(tool_parameters)
|
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||||
|
|
||||||
result = self._invoke(
|
result = self._invoke(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -82,7 +82,7 @@ class Tool(ABC):
|
|||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Transform tool parameters type
|
Transform tool parameters type
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +0,0 @@
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
|
|
||||||
<rect x="3" y="3" width="18" height="18" rx="4" stroke="#2F2F2F" stroke-width="1.5"/>
|
|
||||||
<path d="M7 12h10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
|
|
||||||
<path d="M12 7v10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 332 B |
@ -1,8 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|
||||||
|
|
||||||
|
|
||||||
class AgentOutputProvider(BuiltinToolProviderController):
|
|
||||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
|
||||||
pass
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
identity:
|
|
||||||
author: Dify
|
|
||||||
name: agent_output
|
|
||||||
label:
|
|
||||||
en_US: Agent Output
|
|
||||||
description:
|
|
||||||
en_US: Internal tools for agent output control.
|
|
||||||
icon: icon.svg
|
|
||||||
tags:
|
|
||||||
- utilities
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
|
|
||||||
|
|
||||||
class FinalOutputAnswerTool(BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
|
||||||
yield self.create_text_message("Final answer recorded.")
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: final_output_answer
|
|
||||||
author: Dify
|
|
||||||
label:
|
|
||||||
en_US: Final Output Answer
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Internal tool to deliver the final answer.
|
|
||||||
llm: Use this tool when you are ready to provide the final answer.
|
|
||||||
parameters:
|
|
||||||
- name: text
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
label:
|
|
||||||
en_US: Text
|
|
||||||
human_description:
|
|
||||||
en_US: Final answer text.
|
|
||||||
form: llm
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
|
|
||||||
|
|
||||||
class FinalStructuredOutputTool(BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
|
||||||
yield self.create_text_message("Structured output recorded.")
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: final_structured_output
|
|
||||||
author: Dify
|
|
||||||
label:
|
|
||||||
en_US: Final Structured Output
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Internal tool to deliver structured output.
|
|
||||||
llm: Use this tool to provide structured output data.
|
|
||||||
parameters:
|
|
||||||
- name: data
|
|
||||||
type: object
|
|
||||||
required: true
|
|
||||||
label:
|
|
||||||
en_US: Data
|
|
||||||
human_description:
|
|
||||||
en_US: Structured output data.
|
|
||||||
form: llm
|
|
||||||
@ -1,21 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
|
|
||||||
|
|
||||||
class IllegalOutputTool(BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
|
||||||
message = (
|
|
||||||
"Protocol violation: do not output plain text. "
|
|
||||||
"Call an output tool and finish with the configured terminal tool."
|
|
||||||
)
|
|
||||||
yield self.create_text_message(message)
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: illegal_output
|
|
||||||
author: Dify
|
|
||||||
label:
|
|
||||||
en_US: Illegal Output
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Internal tool for output protocol violations.
|
|
||||||
llm: Use this tool to correct output protocol violations.
|
|
||||||
parameters:
|
|
||||||
- name: raw
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
label:
|
|
||||||
en_US: Raw Output
|
|
||||||
human_description:
|
|
||||||
en_US: Raw model output that violated the protocol.
|
|
||||||
form: llm
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
|
||||||
|
|
||||||
|
|
||||||
class OutputTextTool(BuiltinTool):
|
|
||||||
def _invoke(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
conversation_id: str | None = None,
|
|
||||||
app_id: str | None = None,
|
|
||||||
message_id: str | None = None,
|
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
|
||||||
yield self.create_text_message("Output recorded.")
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
identity:
|
|
||||||
name: output_text
|
|
||||||
author: Dify
|
|
||||||
label:
|
|
||||||
en_US: Output Text
|
|
||||||
description:
|
|
||||||
human:
|
|
||||||
en_US: Internal tool to store intermediate text output.
|
|
||||||
llm: Use this tool to emit non-final text output.
|
|
||||||
parameters:
|
|
||||||
- name: text
|
|
||||||
type: string
|
|
||||||
required: true
|
|
||||||
label:
|
|
||||||
en_US: Text
|
|
||||||
human_description:
|
|
||||||
en_US: Output text.
|
|
||||||
form: llm
|
|
||||||
@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
@ -33,8 +31,9 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
|
|||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.agent.entities import AgentToolEntity
|
|
||||||
from core.workflow.nodes.tool.entities import ToolEntity
|
from core.workflow.nodes.tool.entities import ToolEntity
|
||||||
|
|
||||||
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
@ -67,8 +66,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"}
|
|
||||||
|
|
||||||
|
|
||||||
class ApiProviderControllerItem(TypedDict):
|
class ApiProviderControllerItem(TypedDict):
|
||||||
provider: ApiToolProvider
|
provider: ApiToolProvider
|
||||||
@ -364,7 +361,7 @@ class ToolManager:
|
|||||||
app_id: str,
|
app_id: str,
|
||||||
agent_tool: AgentToolEntity,
|
agent_tool: AgentToolEntity,
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
variable_pool: Optional[VariablePool] = None,
|
variable_pool: Optional["VariablePool"] = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get the agent tool runtime
|
get the agent tool runtime
|
||||||
@ -404,9 +401,9 @@ class ToolManager:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_id: str,
|
app_id: str,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
workflow_tool: ToolEntity,
|
workflow_tool: "ToolEntity",
|
||||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||||
variable_pool: Optional[VariablePool] = None,
|
variable_pool: Optional["VariablePool"] = None,
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
"""
|
"""
|
||||||
get the workflow tool runtime
|
get the workflow tool runtime
|
||||||
@ -594,10 +591,6 @@ class ToolManager:
|
|||||||
cls._hardcoded_providers = {}
|
cls._hardcoded_providers = {}
|
||||||
cls._builtin_providers_loaded = False
|
cls._builtin_providers_loaded = False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_internal_builtin_provider(cls, provider_name: str) -> bool:
|
|
||||||
return provider_name in INTERNAL_BUILTIN_TOOL_PROVIDERS
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
|
||||||
"""
|
"""
|
||||||
@ -635,9 +628,9 @@ class ToolManager:
|
|||||||
# MySQL: Use window function to achieve same result
|
# MySQL: Use window function to achieve same result
|
||||||
sql = """
|
sql = """
|
||||||
SELECT id FROM (
|
SELECT id FROM (
|
||||||
SELECT id,
|
SELECT id,
|
||||||
ROW_NUMBER() OVER (
|
ROW_NUMBER() OVER (
|
||||||
PARTITION BY tenant_id, provider
|
PARTITION BY tenant_id, provider
|
||||||
ORDER BY is_default DESC, created_at DESC
|
ORDER BY is_default DESC, created_at DESC
|
||||||
) as rn
|
) as rn
|
||||||
FROM tool_builtin_providers
|
FROM tool_builtin_providers
|
||||||
@ -674,8 +667,6 @@ class ToolManager:
|
|||||||
|
|
||||||
# append builtin providers
|
# append builtin providers
|
||||||
for provider in builtin_providers:
|
for provider in builtin_providers:
|
||||||
if cls.is_internal_builtin_provider(provider.entity.identity.name):
|
|
||||||
continue
|
|
||||||
# handle include, exclude
|
# handle include, exclude
|
||||||
if is_filtered(
|
if is_filtered(
|
||||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||||
@ -1019,7 +1010,7 @@ class ToolManager:
|
|||||||
def _convert_tool_parameters_type(
|
def _convert_tool_parameters_type(
|
||||||
cls,
|
cls,
|
||||||
parameters: list[ToolParameter],
|
parameters: list[ToolParameter],
|
||||||
variable_pool: Optional[VariablePool],
|
variable_pool: Optional["VariablePool"],
|
||||||
tool_configurations: dict[str, Any],
|
tool_configurations: dict[str, Any],
|
||||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
|||||||
@ -577,12 +577,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
structured_output_schema=None,
|
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||||
|
structured_output=None,
|
||||||
file_saver=self._llm_file_saver,
|
file_saver=self._llm_file_saver,
|
||||||
file_outputs=self._file_outputs,
|
file_outputs=self._file_outputs,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
node_type=self.node_type,
|
node_type=self.node_type,
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for event in generator:
|
for event in generator:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
|
|||||||
from core.agent.entities import AgentLog, AgentResult
|
from core.agent.entities import AgentLog, AgentResult
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||||
from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from core.workflow.entities import ToolCall, ToolCallResult
|
from core.workflow.entities import ToolCall, ToolCallResult
|
||||||
@ -156,9 +156,6 @@ class LLMGenerationData(BaseModel):
|
|||||||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||||
files: list[File] = Field(default_factory=list, description="Generated files")
|
files: list[File] = Field(default_factory=list, description="Generated files")
|
||||||
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
|
||||||
structured_output: LLMStructuredOutput | None = Field(
|
|
||||||
default=None, description="Structured output from tool-only agent runs"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ThinkTagStreamParser:
|
class ThinkTagStreamParser:
|
||||||
@ -287,7 +284,6 @@ class AggregatedResult(BaseModel):
|
|||||||
files: list[File] = Field(default_factory=list)
|
files: list[File] = Field(default_factory=list)
|
||||||
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
structured_output: LLMStructuredOutput | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class AgentContext(BaseModel):
|
class AgentContext(BaseModel):
|
||||||
@ -387,7 +383,7 @@ class LLMNodeData(BaseNodeData):
|
|||||||
Strategy for handling model reasoning output.
|
Strategy for handling model reasoning output.
|
||||||
|
|
||||||
separated: Return clean text (without <think> tags) + reasoning_content field.
|
separated: Return clean text (without <think> tags) + reasoning_content field.
|
||||||
Recommended for new workflows. Enables safe downstream parsing and
|
Recommended for new workflows. Enables safe downstream parsing and
|
||||||
workflow variable access: {{#node_id.reasoning_content#}}
|
workflow variable access: {{#node_id.reasoning_content#}}
|
||||||
|
|
||||||
tagged : Return original text (with <think> tags) + reasoning_content field.
|
tagged : Return original text (with <think> tags) + reasoning_content field.
|
||||||
|
|||||||
@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
|
|||||||
"""
|
"""
|
||||||
Build a text description of generated files for inclusion in context.
|
Build a text description of generated files for inclusion in context.
|
||||||
|
|
||||||
The description includes file_id for context; structured output file paths
|
The description includes file_id which can be used by subsequent nodes
|
||||||
are only supported in sandbox mode.
|
to reference the files via structured output.
|
||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@ -13,24 +13,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||||
from core.agent.output_tools import (
|
|
||||||
FINAL_OUTPUT_TOOL,
|
|
||||||
FINAL_STRUCTURED_OUTPUT_TOOL,
|
|
||||||
OUTPUT_TEXT_TOOL,
|
|
||||||
build_agent_output_tools,
|
|
||||||
)
|
|
||||||
from core.agent.patterns import StrategyFactory
|
from core.agent.patterns import StrategyFactory
|
||||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.app_assets.constants import AppAssetsAttrs
|
from core.app_assets.constants import AppAssetsAttrs
|
||||||
from core.file import FileTransferMethod, FileType, file_manager
|
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.file_ref import (
|
|
||||||
adapt_schema_for_sandbox_file_paths,
|
|
||||||
convert_sandbox_file_paths_in_output,
|
|
||||||
detect_file_path_fields,
|
|
||||||
)
|
|
||||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||||
from core.memory.base import BaseMemory
|
from core.memory.base import BaseMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
@ -73,7 +62,6 @@ from core.skill.entities.skill_document import SkillDocument
|
|||||||
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
|
||||||
from core.skill.skill_compiler import SkillCompiler
|
from core.skill.skill_compiler import SkillCompiler
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
|
||||||
from core.tools.signature import sign_upload_file
|
from core.tools.signature import sign_upload_file
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
@ -197,11 +185,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
node_inputs: dict[str, Any] = {}
|
node_inputs: dict[str, Any] = {}
|
||||||
process_data: dict[str, Any] = {}
|
process_data: dict[str, Any] = {}
|
||||||
usage: LLMUsage = LLMUsage.empty_usage()
|
clean_text = ""
|
||||||
finish_reason: str | None = None
|
usage = LLMUsage.empty_usage()
|
||||||
reasoning_content: str = "" # Initialize as empty string for consistency
|
finish_reason = None
|
||||||
|
reasoning_content = "" # Initialize as empty string for consistency
|
||||||
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
|
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
|
||||||
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse prompt template to separate static messages and context references
|
# Parse prompt template to separate static messages and context references
|
||||||
@ -261,9 +250,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
query: str | None = None
|
query: str | None = None
|
||||||
memory_config = self.node_data.memory
|
if self.node_data.memory:
|
||||||
if memory_config:
|
query = self.node_data.memory.query_prompt_template
|
||||||
query = memory_config.query_prompt_template
|
|
||||||
if not query and (
|
if not query and (
|
||||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||||
):
|
):
|
||||||
@ -305,23 +293,9 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
sandbox=self.graph_runtime_state.sandbox,
|
sandbox=self.graph_runtime_state.sandbox,
|
||||||
)
|
)
|
||||||
|
|
||||||
structured_output_schema: Mapping[str, Any] | None
|
# Variables for outputs
|
||||||
structured_output_file_paths: list[str] = []
|
generation_data: LLMGenerationData | None = None
|
||||||
|
structured_output: LLMStructuredOutput | None = None
|
||||||
if self.node_data.structured_output_enabled:
|
|
||||||
if not self.node_data.structured_output:
|
|
||||||
raise ValueError("structured_output_enabled is True but structured_output is not set")
|
|
||||||
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
|
|
||||||
if self.node_data.computer_use:
|
|
||||||
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
|
|
||||||
raw_schema
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if detect_file_path_fields(raw_schema):
|
|
||||||
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
|
|
||||||
structured_output_schema = raw_schema
|
|
||||||
else:
|
|
||||||
structured_output_schema = None
|
|
||||||
|
|
||||||
if self.node_data.computer_use:
|
if self.node_data.computer_use:
|
||||||
sandbox = self.graph_runtime_state.sandbox
|
sandbox = self.graph_runtime_state.sandbox
|
||||||
@ -335,10 +309,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
stop=stop,
|
stop=stop,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
tool_dependencies=tool_dependencies,
|
tool_dependencies=tool_dependencies,
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
structured_output_file_paths=structured_output_file_paths,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.tool_call_enabled:
|
elif self.tool_call_enabled:
|
||||||
generator = self._invoke_llm_with_tools(
|
generator = self._invoke_llm_with_tools(
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
@ -348,7 +319,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
node_inputs=node_inputs,
|
node_inputs=node_inputs,
|
||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use traditional LLM invocation
|
# Use traditional LLM invocation
|
||||||
@ -358,7 +328,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
structured_output_schema=structured_output_schema,
|
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||||
|
structured_output=self._node_data.structured_output,
|
||||||
file_saver=self._llm_file_saver,
|
file_saver=self._llm_file_saver,
|
||||||
file_outputs=self._file_outputs,
|
file_outputs=self._file_outputs,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -384,8 +355,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
reasoning_content = ""
|
reasoning_content = ""
|
||||||
usage = generation_data.usage
|
usage = generation_data.usage
|
||||||
finish_reason = generation_data.finish_reason
|
finish_reason = generation_data.finish_reason
|
||||||
if generation_data.structured_output:
|
|
||||||
structured_output = generation_data.structured_output
|
|
||||||
|
|
||||||
# Unified process_data building
|
# Unified process_data building
|
||||||
process_data = {
|
process_data = {
|
||||||
@ -409,7 +378,16 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
if tool.enabled
|
if tool.enabled
|
||||||
]
|
]
|
||||||
|
|
||||||
# Build generation field and determine files_to_output first
|
# Unified outputs building
|
||||||
|
outputs = {
|
||||||
|
"text": clean_text,
|
||||||
|
"reasoning_content": reasoning_content,
|
||||||
|
"usage": jsonable_encoder(usage),
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build generation field
|
||||||
if generation_data:
|
if generation_data:
|
||||||
# Use generation_data from tool invocation (supports multi-turn)
|
# Use generation_data from tool invocation (supports multi-turn)
|
||||||
generation = {
|
generation = {
|
||||||
@ -437,15 +415,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
}
|
}
|
||||||
files_to_output = self._file_outputs
|
files_to_output = self._file_outputs
|
||||||
|
|
||||||
# Unified outputs building (files passed to context for subsequent node reference)
|
|
||||||
outputs = {
|
|
||||||
"text": clean_text,
|
|
||||||
"reasoning_content": reasoning_content,
|
|
||||||
"usage": jsonable_encoder(usage),
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data, files=files_to_output),
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs["generation"] = generation
|
outputs["generation"] = generation
|
||||||
if files_to_output:
|
if files_to_output:
|
||||||
outputs["files"] = ArrayFileSegment(value=files_to_output)
|
outputs["files"] = ArrayFileSegment(value=files_to_output)
|
||||||
@ -524,7 +493,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
prompt_messages: Sequence[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
stop: Sequence[str] | None = None,
|
stop: Sequence[str] | None = None,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
structured_output_schema: Mapping[str, Any] | None,
|
structured_output_enabled: bool,
|
||||||
|
structured_output: Mapping[str, Any] | None = None,
|
||||||
file_saver: LLMFileSaver,
|
file_saver: LLMFileSaver,
|
||||||
file_outputs: list[File],
|
file_outputs: list[File],
|
||||||
node_id: str,
|
node_id: str,
|
||||||
@ -538,7 +508,10 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
if not model_schema:
|
if not model_schema:
|
||||||
raise ValueError(f"Model schema not found for {node_data_model.name}")
|
raise ValueError(f"Model schema not found for {node_data_model.name}")
|
||||||
|
|
||||||
if structured_output_schema:
|
if structured_output_enabled:
|
||||||
|
output_schema = LLMNode.fetch_structured_output_schema(
|
||||||
|
structured_output=structured_output or {},
|
||||||
|
)
|
||||||
request_start_time = time.perf_counter()
|
request_start_time = time.perf_counter()
|
||||||
|
|
||||||
invoke_result = invoke_llm_with_structured_output(
|
invoke_result = invoke_llm_with_structured_output(
|
||||||
@ -546,12 +519,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=structured_output_schema,
|
json_schema=output_schema,
|
||||||
model_parameters=node_data_model.completion_params,
|
model_parameters=node_data_model.completion_params,
|
||||||
stop=list(stop or []),
|
stop=list(stop or []),
|
||||||
user=user_id,
|
user=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
request_start_time = time.perf_counter()
|
request_start_time = time.perf_counter()
|
||||||
|
|
||||||
@ -1288,16 +1261,18 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# Insert histories into the prompt
|
# Insert histories into the prompt
|
||||||
prompt_content = prompt_messages[0].content
|
prompt_content = prompt_messages[0].content
|
||||||
# For issue #11247 - Check if prompt content is a string or a list
|
# For issue #11247 - Check if prompt content is a string or a list
|
||||||
if isinstance(prompt_content, str):
|
prompt_content_type = type(prompt_content)
|
||||||
|
if prompt_content_type == str:
|
||||||
prompt_content = str(prompt_content)
|
prompt_content = str(prompt_content)
|
||||||
if "#histories#" in prompt_content:
|
if "#histories#" in prompt_content:
|
||||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
else:
|
else:
|
||||||
prompt_content = memory_text + "\n" + prompt_content
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
prompt_messages[0].content = prompt_content
|
prompt_messages[0].content = prompt_content
|
||||||
elif isinstance(prompt_content, list):
|
elif prompt_content_type == list:
|
||||||
|
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||||
for content_item in prompt_content:
|
for content_item in prompt_content:
|
||||||
if isinstance(content_item, TextPromptMessageContent):
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
if "#histories#" in content_item.data:
|
if "#histories#" in content_item.data:
|
||||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||||
else:
|
else:
|
||||||
@ -1307,12 +1282,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
|
|
||||||
# Add current query to the prompt message
|
# Add current query to the prompt message
|
||||||
if sys_query:
|
if sys_query:
|
||||||
if isinstance(prompt_content, str):
|
if prompt_content_type == str:
|
||||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||||
prompt_messages[0].content = prompt_content
|
prompt_messages[0].content = prompt_content
|
||||||
elif isinstance(prompt_content, list):
|
elif prompt_content_type == list:
|
||||||
|
prompt_content = prompt_content if isinstance(prompt_content, list) else []
|
||||||
for content_item in prompt_content:
|
for content_item in prompt_content:
|
||||||
if isinstance(content_item, TextPromptMessageContent):
|
if content_item.type == PromptMessageContentType.TEXT:
|
||||||
content_item.data = sys_query + "\n" + content_item.data
|
content_item.data = sys_query + "\n" + content_item.data
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid prompt content type")
|
raise ValueError("Invalid prompt content type")
|
||||||
@ -1438,11 +1414,9 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
if isinstance(item, PromptMessageContext):
|
if isinstance(item, PromptMessageContext):
|
||||||
if len(item.value_selector) >= 2:
|
if len(item.value_selector) >= 2:
|
||||||
prompt_context_selectors.append(item.value_selector)
|
prompt_context_selectors.append(item.value_selector)
|
||||||
elif isinstance(item, LLMNodeChatModelMessage):
|
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||||
variable_template_parser = VariableTemplateParser(template=item.text)
|
variable_template_parser = VariableTemplateParser(template=item.text)
|
||||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||||
else:
|
|
||||||
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
|
||||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
if prompt_template.edition_type != "jinja2":
|
if prompt_template.edition_type != "jinja2":
|
||||||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||||
@ -1478,14 +1452,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
if typed_node_data.prompt_config:
|
if typed_node_data.prompt_config:
|
||||||
enable_jinja = False
|
enable_jinja = False
|
||||||
|
|
||||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
if isinstance(prompt_template, list):
|
||||||
if prompt_template.edition_type == "jinja2":
|
for item in prompt_template:
|
||||||
enable_jinja = True
|
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
|
||||||
else:
|
|
||||||
for prompt in prompt_template:
|
|
||||||
if prompt.edition_type == "jinja2":
|
|
||||||
enable_jinja = True
|
enable_jinja = True
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
enable_jinja = True
|
||||||
|
|
||||||
if enable_jinja:
|
if enable_jinja:
|
||||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||||
@ -1898,7 +1871,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
node_inputs: dict[str, Any],
|
node_inputs: dict[str, Any],
|
||||||
process_data: dict[str, Any],
|
process_data: dict[str, Any],
|
||||||
structured_output_schema: Mapping[str, Any] | None,
|
|
||||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||||
"""Invoke LLM with tools support (from Agent V2).
|
"""Invoke LLM with tools support (from Agent V2).
|
||||||
|
|
||||||
@ -1915,16 +1887,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
|
|
||||||
# Use factory to create appropriate strategy
|
# Use factory to create appropriate strategy
|
||||||
strategy = StrategyFactory.create_strategy(
|
strategy = StrategyFactory.create_strategy(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
|
||||||
model_features=model_features,
|
model_features=model_features,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
tools=tool_instances,
|
tools=tool_instances,
|
||||||
files=prompt_files,
|
files=prompt_files,
|
||||||
max_iterations=self._node_data.max_iterations or 10,
|
max_iterations=self._node_data.max_iterations or 10,
|
||||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run strategy
|
# Run strategy
|
||||||
@ -1932,6 +1900,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
prompt_messages=list(prompt_messages),
|
prompt_messages=list(prompt_messages),
|
||||||
model_parameters=self._node_data.model.completion_params,
|
model_parameters=self._node_data.model.completion_params,
|
||||||
stop=list(stop or []),
|
stop=list(stop or []),
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield from self._process_tool_outputs(outputs)
|
result = yield from self._process_tool_outputs(outputs)
|
||||||
@ -1945,12 +1914,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
stop: Sequence[str] | None,
|
stop: Sequence[str] | None,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
tool_dependencies: ToolDependencies | None,
|
tool_dependencies: ToolDependencies | None,
|
||||||
structured_output_schema: Mapping[str, Any] | None,
|
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||||
structured_output_file_paths: Sequence[str] | None,
|
|
||||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
|
|
||||||
result: LLMGenerationData | None = None
|
result: LLMGenerationData | None = None
|
||||||
sandbox_output_files: list[File] = []
|
|
||||||
structured_output_files: list[File] = []
|
|
||||||
|
|
||||||
# FIXME(Mairuis): Async processing for bash session.
|
# FIXME(Mairuis): Async processing for bash session.
|
||||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||||
@ -1958,9 +1923,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
model_features = self._get_model_features(model_instance)
|
model_features = self._get_model_features(model_instance)
|
||||||
|
|
||||||
strategy = StrategyFactory.create_strategy(
|
strategy = StrategyFactory.create_strategy(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
|
||||||
model_features=model_features,
|
model_features=model_features,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
tools=[session.bash_tool],
|
tools=[session.bash_tool],
|
||||||
@ -1968,55 +1930,20 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
max_iterations=self._node_data.max_iterations or 100,
|
max_iterations=self._node_data.max_iterations or 100,
|
||||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = strategy.run(
|
outputs = strategy.run(
|
||||||
prompt_messages=list(prompt_messages),
|
prompt_messages=list(prompt_messages),
|
||||||
model_parameters=self._node_data.model.completion_params,
|
model_parameters=self._node_data.model.completion_params,
|
||||||
stop=list(stop or []),
|
stop=list(stop or []),
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield from self._process_tool_outputs(outputs)
|
result = yield from self._process_tool_outputs(outputs)
|
||||||
|
|
||||||
if result and result.structured_output and structured_output_file_paths:
|
|
||||||
structured_output_payload = result.structured_output.structured_output or {}
|
|
||||||
try:
|
|
||||||
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
|
|
||||||
output=structured_output_payload,
|
|
||||||
file_path_fields=structured_output_file_paths,
|
|
||||||
file_resolver=session.download_file,
|
|
||||||
)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise LLMNodeError(str(exc)) from exc
|
|
||||||
result = result.model_copy(
|
|
||||||
update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect output files from sandbox before session ends
|
|
||||||
# Files are saved as ToolFiles with valid tool_file_id for later reference
|
|
||||||
sandbox_output_files = session.collect_output_files()
|
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise LLMNodeError("SandboxSession exited unexpectedly")
|
raise LLMNodeError("SandboxSession exited unexpectedly")
|
||||||
|
|
||||||
structured_output = result.structured_output
|
|
||||||
if structured_output is not None:
|
|
||||||
yield structured_output
|
|
||||||
|
|
||||||
# Merge sandbox output files into result
|
|
||||||
if sandbox_output_files or structured_output_files:
|
|
||||||
result = LLMGenerationData(
|
|
||||||
text=result.text,
|
|
||||||
reasoning_contents=result.reasoning_contents,
|
|
||||||
tool_calls=result.tool_calls,
|
|
||||||
sequence=result.sequence,
|
|
||||||
usage=result.usage,
|
|
||||||
finish_reason=result.finish_reason,
|
|
||||||
files=result.files + sandbox_output_files + structured_output_files,
|
|
||||||
trace=result.trace,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
|
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
|
||||||
@ -2084,20 +2011,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
logger.warning("Failed to load tool %s: %s", tool, str(e))
|
logger.warning("Failed to load tool %s: %s", tool, str(e))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
structured_output_schema = None
|
|
||||||
if self._node_data.structured_output_enabled:
|
|
||||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
|
||||||
structured_output=self._node_data.structured_output or {},
|
|
||||||
)
|
|
||||||
tool_instances.extend(
|
|
||||||
build_agent_output_tools(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
|
||||||
structured_output_schema=structured_output_schema,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return tool_instances
|
return tool_instances
|
||||||
|
|
||||||
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
|
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
|
||||||
@ -2261,45 +2174,18 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
# Add tool call to pending list for model segment
|
# Add tool call to pending list for model segment
|
||||||
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
|
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
|
||||||
|
|
||||||
output_tool_names = {OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL}
|
yield ToolCallChunkEvent(
|
||||||
|
selector=[self._node_id, "generation", "tool_calls"],
|
||||||
if tool_name not in output_tool_names:
|
chunk=tool_arguments,
|
||||||
yield ToolCallChunkEvent(
|
tool_call=ToolCall(
|
||||||
selector=[self._node_id, "generation", "tool_calls"],
|
id=tool_call_id,
|
||||||
chunk=tool_arguments,
|
name=tool_name,
|
||||||
tool_call=ToolCall(
|
arguments=tool_arguments,
|
||||||
id=tool_call_id,
|
icon=tool_icon,
|
||||||
name=tool_name,
|
icon_dark=tool_icon_dark,
|
||||||
arguments=tool_arguments,
|
),
|
||||||
icon=tool_icon,
|
is_final=False,
|
||||||
icon_dark=tool_icon_dark,
|
)
|
||||||
),
|
|
||||||
is_final=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if tool_name in output_tool_names:
|
|
||||||
content = ""
|
|
||||||
if tool_name in (OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL):
|
|
||||||
content = payload.tool_args["text"]
|
|
||||||
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
|
|
||||||
raw_content = json.dumps(
|
|
||||||
payload.tool_args["data"],
|
|
||||||
ensure_ascii=False,
|
|
||||||
indent=2
|
|
||||||
)
|
|
||||||
content = f"```json\n{raw_content}\n```"
|
|
||||||
|
|
||||||
if content:
|
|
||||||
yield StreamChunkEvent(
|
|
||||||
selector=[self._node_id, "text"],
|
|
||||||
chunk=content,
|
|
||||||
is_final=False,
|
|
||||||
)
|
|
||||||
yield StreamChunkEvent(
|
|
||||||
selector=[self._node_id, "generation", "content"],
|
|
||||||
chunk=content,
|
|
||||||
is_final=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
|
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
|
||||||
tool_name = payload.tool_name
|
tool_name = payload.tool_name
|
||||||
@ -2543,7 +2429,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
content_position = 0
|
content_position = 0
|
||||||
tool_call_seen_index: dict[str, int] = {}
|
tool_call_seen_index: dict[str, int] = {}
|
||||||
for trace_segment in trace_state.trace_segments:
|
for trace_segment in trace_state.trace_segments:
|
||||||
# FIXME: These if will never happen
|
|
||||||
if trace_segment.type == "thought":
|
if trace_segment.type == "thought":
|
||||||
sequence.append({"type": "reasoning", "index": reasoning_index})
|
sequence.append({"type": "reasoning", "index": reasoning_index})
|
||||||
reasoning_index += 1
|
reasoning_index += 1
|
||||||
@ -2586,15 +2471,8 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
|
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
|
||||||
)
|
)
|
||||||
|
|
||||||
text_content: str
|
|
||||||
if aggregate.text:
|
|
||||||
text_content = aggregate.text
|
|
||||||
elif aggregate.structured_output:
|
|
||||||
text_content = json.dumps(aggregate.structured_output.structured_output)
|
|
||||||
else:
|
|
||||||
raise ValueError("Aggregate must have either text or structured output.")
|
|
||||||
return LLMGenerationData(
|
return LLMGenerationData(
|
||||||
text=text_content,
|
text=aggregate.text,
|
||||||
reasoning_contents=buffers.reasoning_per_turn,
|
reasoning_contents=buffers.reasoning_per_turn,
|
||||||
tool_calls=tool_calls_for_generation,
|
tool_calls=tool_calls_for_generation,
|
||||||
sequence=sequence,
|
sequence=sequence,
|
||||||
@ -2602,7 +2480,6 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
finish_reason=aggregate.finish_reason,
|
finish_reason=aggregate.finish_reason,
|
||||||
files=aggregate.files,
|
files=aggregate.files,
|
||||||
trace=trace_state.trace_segments,
|
trace=trace_state.trace_segments,
|
||||||
structured_output=aggregate.structured_output,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _process_tool_outputs(
|
def _process_tool_outputs(
|
||||||
@ -2613,33 +2490,22 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
state = ToolOutputState()
|
state = ToolOutputState()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
for output in outputs:
|
||||||
output = next(outputs)
|
|
||||||
if isinstance(output, AgentLog):
|
if isinstance(output, AgentLog):
|
||||||
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
|
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
|
||||||
else:
|
else:
|
||||||
continue
|
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
|
||||||
except StopIteration as exception:
|
except StopIteration as exception:
|
||||||
if not isinstance(exception.value, AgentResult):
|
if isinstance(getattr(exception, "value", None), AgentResult):
|
||||||
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
|
state.agent.agent_result = exception.value
|
||||||
state.agent.agent_result = exception.value
|
|
||||||
agent_result = state.agent.agent_result
|
|
||||||
if not agent_result:
|
|
||||||
raise ValueError("No agent result found in tool outputs")
|
|
||||||
output_payload = agent_result.output
|
|
||||||
if isinstance(output_payload, dict):
|
|
||||||
state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
|
|
||||||
state.aggregate.text = json.dumps(output_payload)
|
|
||||||
elif isinstance(output_payload, str):
|
|
||||||
state.aggregate.text = output_payload
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected output type: {type(output_payload)}")
|
|
||||||
|
|
||||||
state.aggregate.files = state.agent.agent_result.files
|
if state.agent.agent_result:
|
||||||
if state.agent.agent_result.usage:
|
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
|
||||||
state.aggregate.usage = state.agent.agent_result.usage
|
state.aggregate.files = state.agent.agent_result.files
|
||||||
if state.agent.agent_result.finish_reason:
|
if state.agent.agent_result.usage:
|
||||||
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
state.aggregate.usage = state.agent.agent_result.usage
|
||||||
|
if state.agent.agent_result.finish_reason:
|
||||||
|
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
|
||||||
|
|
||||||
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
|
||||||
yield from self._close_streams()
|
yield from self._close_streams()
|
||||||
|
|||||||
@ -156,7 +156,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
|||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
structured_output_schema=None,
|
structured_output_enabled=False,
|
||||||
|
structured_output=None,
|
||||||
file_saver=self._llm_file_saver,
|
file_saver=self._llm_file_saver,
|
||||||
file_outputs=self._file_outputs,
|
file_outputs=self._file_outputs,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
|
|||||||
@ -91,8 +91,6 @@ class BuiltinToolManageService:
|
|||||||
:return: the list of tools
|
:return: the list of tools
|
||||||
"""
|
"""
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
|
|
||||||
return []
|
|
||||||
tools = provider_controller.get_tools()
|
tools = provider_controller.get_tools()
|
||||||
|
|
||||||
result: list[ToolApiEntity] = []
|
result: list[ToolApiEntity] = []
|
||||||
@ -543,8 +541,6 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
for provider_controller in provider_controllers:
|
for provider_controller in provider_controllers:
|
||||||
try:
|
try:
|
||||||
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
|
|
||||||
continue
|
|
||||||
# handle include, exclude
|
# handle include, exclude
|
||||||
if is_filtered(
|
if is_filtered(
|
||||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||||
|
|||||||
5
api/tests/fixtures/file output schema.yml
vendored
5
api/tests/fixtures/file output schema.yml
vendored
@ -126,8 +126,9 @@ workflow:
|
|||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
image:
|
image:
|
||||||
description: Sandbox file path of the selected image
|
description: File ID (UUID) of the selected image
|
||||||
type: file
|
format: dify-file-ref
|
||||||
|
type: string
|
||||||
required:
|
required:
|
||||||
- image
|
- image
|
||||||
type: object
|
type: object
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from core.agent.entities import AgentScratchpadUnit
|
from core.agent.entities import AgentScratchpadUnit
|
||||||
@ -63,7 +64,7 @@ def test_cot_output_parser():
|
|||||||
output += result
|
output += result
|
||||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||||
if test_case["action"]:
|
if test_case["action"]:
|
||||||
assert result.model_dump() == test_case["action"]
|
assert result.to_dict() == test_case["action"]
|
||||||
output += result.model_dump_json()
|
output += json.dumps(result.to_dict())
|
||||||
if test_case["output"]:
|
if test_case["output"]:
|
||||||
assert output == test_case["output"]
|
assert output == test_case["output"]
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
|||||||
class ConcreteAgentPattern(AgentPattern):
|
class ConcreteAgentPattern(AgentPattern):
|
||||||
"""Concrete implementation of AgentPattern for testing."""
|
"""Concrete implementation of AgentPattern for testing."""
|
||||||
|
|
||||||
def run(self, prompt_messages, model_parameters, stop=[]):
|
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
|
||||||
"""Minimal implementation for testing."""
|
"""Minimal implementation for testing."""
|
||||||
yield from []
|
yield from []
|
||||||
|
|
||||||
|
|||||||
@ -329,15 +329,13 @@ class TestAgentLogProcessing:
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = AgentResult(
|
result = AgentResult(
|
||||||
output="Final answer",
|
text="Final answer",
|
||||||
files=[],
|
files=[],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
)
|
)
|
||||||
|
|
||||||
output_payload = result.output
|
assert result.text == "Final answer"
|
||||||
assert isinstance(output_payload, str)
|
|
||||||
assert output_payload == "Final answer"
|
|
||||||
assert result.files == []
|
assert result.files == []
|
||||||
assert result.usage == usage
|
assert result.usage == usage
|
||||||
assert result.finish_reason == "stop"
|
assert result.finish_reason == "stop"
|
||||||
|
|||||||
@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
|
|||||||
action_input={"query": "test"},
|
action_input={"query": "test"},
|
||||||
)
|
)
|
||||||
|
|
||||||
result = action.model_dump()
|
result = action.to_dict()
|
||||||
|
|
||||||
assert result == {
|
assert result == {
|
||||||
"action": "search",
|
"action": "search",
|
||||||
|
|||||||
@ -1,93 +1,269 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for sandbox file path detection and conversion.
|
Unit tests for file reference detection and conversion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.llm_generator.output_parser.file_ref import (
|
from core.llm_generator.output_parser.file_ref import (
|
||||||
FILE_PATH_DESCRIPTION_SUFFIX,
|
FILE_REF_FORMAT,
|
||||||
adapt_schema_for_sandbox_file_paths,
|
convert_file_refs_in_output,
|
||||||
convert_sandbox_file_paths_in_output,
|
detect_file_ref_fields,
|
||||||
detect_file_path_fields,
|
is_file_ref_property,
|
||||||
is_file_path_property,
|
|
||||||
)
|
)
|
||||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||||
|
|
||||||
|
|
||||||
def _build_file(file_id: str) -> File:
|
class TestIsFileRefProperty:
|
||||||
return File(
|
"""Tests for is_file_ref_property function."""
|
||||||
id=file_id,
|
|
||||||
tenant_id="tenant_123",
|
def test_valid_file_ref(self):
|
||||||
type=FileType.IMAGE,
|
schema = {"type": "string", "format": FILE_REF_FORMAT}
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
assert is_file_ref_property(schema) is True
|
||||||
filename="test.png",
|
|
||||||
extension=".png",
|
def test_invalid_type(self):
|
||||||
mime_type="image/png",
|
schema = {"type": "number", "format": FILE_REF_FORMAT}
|
||||||
size=128,
|
assert is_file_ref_property(schema) is False
|
||||||
related_id=file_id,
|
|
||||||
storage_key="sandbox/path",
|
def test_missing_format(self):
|
||||||
)
|
schema = {"type": "string"}
|
||||||
|
assert is_file_ref_property(schema) is False
|
||||||
|
|
||||||
|
def test_wrong_format(self):
|
||||||
|
schema = {"type": "string", "format": "uuid"}
|
||||||
|
assert is_file_ref_property(schema) is False
|
||||||
|
|
||||||
|
|
||||||
class TestFilePathSchema:
|
class TestDetectFileRefFields:
|
||||||
def test_is_file_path_property(self):
|
"""Tests for detect_file_ref_fields function."""
|
||||||
assert is_file_path_property({"type": "file"}) is True
|
|
||||||
assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True
|
|
||||||
assert is_file_path_property({"type": "string"}) is False
|
|
||||||
|
|
||||||
def test_detect_file_path_fields(self):
|
def test_simple_file_ref(self):
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"image": {"type": "string", "format": "dify-file-ref"},
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
"files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
|
|
||||||
"meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert paths == ["image"]
|
||||||
|
|
||||||
def test_adapt_schema_for_sandbox_file_paths(self):
|
def test_multiple_file_refs(self):
|
||||||
schema = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"image": {"type": "string", "format": "dify-file-ref"},
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
"document": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
"name": {"type": "string"},
|
"name": {"type": "string"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert set(paths) == {"image", "document"}
|
||||||
|
|
||||||
assert set(fields) == {"image"}
|
def test_array_of_file_refs(self):
|
||||||
adapted_image = adapted["properties"]["image"]
|
schema = {
|
||||||
assert adapted_image["type"] == "string"
|
"type": "object",
|
||||||
assert "format" not in adapted_image
|
"properties": {
|
||||||
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert paths == ["files[*]"]
|
||||||
|
|
||||||
|
def test_nested_file_ref(self):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert paths == ["data.image"]
|
||||||
|
|
||||||
|
def test_no_file_refs(self):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"count": {"type": "number"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert paths == []
|
||||||
|
|
||||||
|
def test_empty_schema(self):
|
||||||
|
schema = {}
|
||||||
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert paths == []
|
||||||
|
|
||||||
|
def test_mixed_schema(self):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
"documents": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
paths = detect_file_ref_fields(schema)
|
||||||
|
assert set(paths) == {"image", "documents[*]"}
|
||||||
|
|
||||||
|
|
||||||
class TestConvertSandboxFilePaths:
|
class TestConvertFileRefsInOutput:
|
||||||
def test_convert_sandbox_file_paths(self):
|
"""Tests for convert_file_refs_in_output function."""
|
||||||
output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
|
|
||||||
|
|
||||||
def resolver(path: str) -> File:
|
@pytest.fixture
|
||||||
return _build_file(path)
|
def mock_file(self):
|
||||||
|
"""Create a mock File object with all required attributes."""
|
||||||
|
file = MagicMock(spec=File)
|
||||||
|
file.type = FileType.IMAGE
|
||||||
|
file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||||
|
file.related_id = "test-related-id"
|
||||||
|
file.remote_url = None
|
||||||
|
file.tenant_id = "tenant_123"
|
||||||
|
file.id = None
|
||||||
|
file.filename = "test.png"
|
||||||
|
file.extension = ".png"
|
||||||
|
file.mime_type = "image/png"
|
||||||
|
file.size = 1024
|
||||||
|
file.dify_model_identity = "__dify__file__"
|
||||||
|
return file
|
||||||
|
|
||||||
converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
|
@pytest.fixture
|
||||||
|
def mock_build_from_mapping(self, mock_file):
|
||||||
|
"""Mock the build_from_mapping function."""
|
||||||
|
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||||
|
mock.return_value = mock_file
|
||||||
|
yield mock
|
||||||
|
|
||||||
assert isinstance(converted["image"], FileSegment)
|
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||||
assert isinstance(converted["files"], ArrayFileSegment)
|
file_id = str(uuid.uuid4())
|
||||||
assert converted["name"] == "demo"
|
output = {"image": file_id}
|
||||||
assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def test_invalid_path_value_raises(self):
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
def resolver(path: str) -> File:
|
|
||||||
return _build_file(path)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
# Result should be wrapped in FileSegment
|
||||||
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
|
assert isinstance(result["image"], FileSegment)
|
||||||
|
assert result["image"].value == mock_file
|
||||||
|
mock_build_from_mapping.assert_called_once_with(
|
||||||
|
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||||
|
tenant_id="tenant_123",
|
||||||
|
)
|
||||||
|
|
||||||
def test_no_file_paths_returns_output(self):
|
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||||
output = {"name": "demo"}
|
file_id1 = str(uuid.uuid4())
|
||||||
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
|
file_id2 = str(uuid.uuid4())
|
||||||
|
output = {"files": [file_id1, file_id2]}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
assert converted == output
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
assert files == []
|
|
||||||
|
# Result should be wrapped in ArrayFileSegment
|
||||||
|
assert isinstance(result["files"], ArrayFileSegment)
|
||||||
|
assert list(result["files"].value) == [mock_file, mock_file]
|
||||||
|
assert mock_build_from_mapping.call_count == 2
|
||||||
|
|
||||||
|
def test_no_conversion_without_file_refs(self):
|
||||||
|
output = {"name": "test", "count": 5}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"count": {"type": "number"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
|
|
||||||
|
assert result == {"name": "test", "count": 5}
|
||||||
|
|
||||||
|
def test_invalid_uuid_returns_none(self):
|
||||||
|
output = {"image": "not-a-valid-uuid"}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
|
|
||||||
|
assert result["image"] is None
|
||||||
|
|
||||||
|
def test_file_not_found_returns_none(self):
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
output = {"image": file_id}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||||
|
mock.side_effect = ValueError("File not found")
|
||||||
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
|
|
||||||
|
assert result["image"] is None
|
||||||
|
|
||||||
|
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
output = {"query": "search term", "image": file_id, "count": 10}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
"count": {"type": "number"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
|
|
||||||
|
assert result["query"] == "search term"
|
||||||
|
assert isinstance(result["image"], FileSegment)
|
||||||
|
assert result["image"].value == mock_file
|
||||||
|
assert result["count"] == 10
|
||||||
|
|
||||||
|
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
original = {"image": file_id}
|
||||||
|
output = dict(original)
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||||
|
|
||||||
|
# Original should still contain the string ID
|
||||||
|
assert original["image"] == file_id
|
||||||
|
|||||||
@ -1,6 +1,4 @@
|
|||||||
from collections.abc import Mapping
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, NotRequired, TypedDict
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,14 +6,16 @@ from pydantic import BaseModel, ConfigDict
|
|||||||
|
|
||||||
from core.llm_generator.output_parser.errors import OutputParserError
|
from core.llm_generator.output_parser.errors import OutputParserError
|
||||||
from core.llm_generator.output_parser.structured_output import (
|
from core.llm_generator.output_parser.structured_output import (
|
||||||
|
_get_default_value_for_type,
|
||||||
fill_defaults_from_schema,
|
fill_defaults_from_schema,
|
||||||
get_default_value_for_type,
|
|
||||||
invoke_llm_with_pydantic_model,
|
invoke_llm_with_pydantic_model,
|
||||||
invoke_llm_with_structured_output,
|
invoke_llm_with_structured_output,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.common_entities import I18nObject
|
|
||||||
from core.model_runtime.entities.llm_entities import (
|
from core.model_runtime.entities.llm_entities import (
|
||||||
LLMResult,
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
LLMResultChunkWithStructuredOutput,
|
||||||
LLMResultWithStructuredOutput,
|
LLMResultWithStructuredOutput,
|
||||||
LLMUsage,
|
LLMUsage,
|
||||||
)
|
)
|
||||||
@ -25,29 +25,7 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
AIModelEntity,
|
|
||||||
ModelType,
|
|
||||||
ParameterRule,
|
|
||||||
ParameterType,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredOutputTestCase(TypedDict):
|
|
||||||
name: str
|
|
||||||
provider: str
|
|
||||||
model_name: str
|
|
||||||
support_structure_output: bool
|
|
||||||
stream: bool
|
|
||||||
json_schema: Mapping[str, Any]
|
|
||||||
expected_llm_response: LLMResult
|
|
||||||
expected_result_type: type[LLMResultWithStructuredOutput] | None
|
|
||||||
should_raise: bool
|
|
||||||
expected_error: NotRequired[type[OutputParserError]]
|
|
||||||
parameter_rules: NotRequired[list[ParameterRule]]
|
|
||||||
|
|
||||||
|
|
||||||
SchemaData = dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||||
@ -93,7 +71,7 @@ def get_model_instance() -> MagicMock:
|
|||||||
def test_structured_output_parser():
|
def test_structured_output_parser():
|
||||||
"""Test cases for invoke_llm_with_structured_output function"""
|
"""Test cases for invoke_llm_with_structured_output function"""
|
||||||
|
|
||||||
testcases: list[StructuredOutputTestCase] = [
|
testcases = [
|
||||||
# Test case 1: Model with native structured output support, non-streaming
|
# Test case 1: Model with native structured output support, non-streaming
|
||||||
{
|
{
|
||||||
"name": "native_structured_output_non_streaming",
|
"name": "native_structured_output_non_streaming",
|
||||||
@ -110,6 +88,39 @@ def test_structured_output_parser():
|
|||||||
"expected_result_type": LLMResultWithStructuredOutput,
|
"expected_result_type": LLMResultWithStructuredOutput,
|
||||||
"should_raise": False,
|
"should_raise": False,
|
||||||
},
|
},
|
||||||
|
# Test case 2: Model with native structured output support, streaming
|
||||||
|
{
|
||||||
|
"name": "native_structured_output_streaming",
|
||||||
|
"provider": "openai",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"support_structure_output": True,
|
||||||
|
"stream": True,
|
||||||
|
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||||
|
"expected_llm_response": [
|
||||||
|
LLMResultChunk(
|
||||||
|
model="gpt-4o",
|
||||||
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content='{"name":'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
LLMResultChunk(
|
||||||
|
model="gpt-4o",
|
||||||
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=' "test"}'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"expected_result_type": "generator",
|
||||||
|
"should_raise": False,
|
||||||
|
},
|
||||||
# Test case 3: Model without native structured output support, non-streaming
|
# Test case 3: Model without native structured output support, non-streaming
|
||||||
{
|
{
|
||||||
"name": "prompt_based_structured_output_non_streaming",
|
"name": "prompt_based_structured_output_non_streaming",
|
||||||
@ -126,24 +137,78 @@ def test_structured_output_parser():
|
|||||||
"expected_result_type": LLMResultWithStructuredOutput,
|
"expected_result_type": LLMResultWithStructuredOutput,
|
||||||
"should_raise": False,
|
"should_raise": False,
|
||||||
},
|
},
|
||||||
|
# Test case 4: Model without native structured output support, streaming
|
||||||
{
|
{
|
||||||
"name": "non_streaming_with_list_content",
|
"name": "prompt_based_structured_output_streaming",
|
||||||
|
"provider": "anthropic",
|
||||||
|
"model_name": "claude-3-sonnet",
|
||||||
|
"support_structure_output": False,
|
||||||
|
"stream": True,
|
||||||
|
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||||
|
"expected_llm_response": [
|
||||||
|
LLMResultChunk(
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
LLMResultChunk(
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=' response"}'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"expected_result_type": "generator",
|
||||||
|
"should_raise": False,
|
||||||
|
},
|
||||||
|
# Test case 5: Streaming with list content
|
||||||
|
{
|
||||||
|
"name": "streaming_with_list_content",
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
"model_name": "gpt-4o",
|
"model_name": "gpt-4o",
|
||||||
"support_structure_output": True,
|
"support_structure_output": True,
|
||||||
"stream": False,
|
"stream": True,
|
||||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||||
"expected_llm_response": LLMResult(
|
"expected_llm_response": [
|
||||||
model="gpt-4o",
|
LLMResultChunk(
|
||||||
message=AssistantPromptMessage(
|
model="gpt-4o",
|
||||||
content=[
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
TextPromptMessageContent(data='{"data":'),
|
system_fingerprint="test",
|
||||||
TextPromptMessageContent(data=' "value"}'),
|
delta=LLMResultChunkDelta(
|
||||||
]
|
index=0,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data='{"data":'),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
LLMResultChunk(
|
||||||
),
|
model="gpt-4o",
|
||||||
"expected_result_type": LLMResultWithStructuredOutput,
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data=' "value"}'),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"expected_result_type": "generator",
|
||||||
"should_raise": False,
|
"should_raise": False,
|
||||||
},
|
},
|
||||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||||
@ -188,13 +253,7 @@ def test_structured_output_parser():
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||||
"parameter_rules": [
|
"parameter_rules": [
|
||||||
ParameterRule(
|
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||||
name="response_format",
|
|
||||||
label=I18nObject(en_US="response_format"),
|
|
||||||
type=ParameterType.STRING,
|
|
||||||
required=False,
|
|
||||||
options=["json_schema"],
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
"expected_llm_response": LLMResult(
|
"expected_llm_response": LLMResult(
|
||||||
model="gpt-4o",
|
model="gpt-4o",
|
||||||
@ -213,13 +272,7 @@ def test_structured_output_parser():
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||||
"parameter_rules": [
|
"parameter_rules": [
|
||||||
ParameterRule(
|
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||||
name="response_format",
|
|
||||||
label=I18nObject(en_US="response_format"),
|
|
||||||
type=ParameterType.STRING,
|
|
||||||
required=False,
|
|
||||||
options=["JSON"],
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
"expected_llm_response": LLMResult(
|
"expected_llm_response": LLMResult(
|
||||||
model="claude-3-sonnet",
|
model="claude-3-sonnet",
|
||||||
@ -232,72 +285,89 @@ def test_structured_output_parser():
|
|||||||
]
|
]
|
||||||
|
|
||||||
for case in testcases:
|
for case in testcases:
|
||||||
provider = case["provider"]
|
# Setup model entity
|
||||||
model_name = case["model_name"]
|
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||||
support_structure_output = case["support_structure_output"]
|
|
||||||
json_schema = case["json_schema"]
|
|
||||||
stream = case["stream"]
|
|
||||||
|
|
||||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
# Add parameter rules if specified
|
||||||
|
if "parameter_rules" in case:
|
||||||
parameter_rules = case.get("parameter_rules")
|
model_schema.parameter_rules = case["parameter_rules"]
|
||||||
if parameter_rules is not None:
|
|
||||||
model_schema.parameter_rules = parameter_rules
|
|
||||||
|
|
||||||
|
# Setup model instance
|
||||||
model_instance = get_model_instance()
|
model_instance = get_model_instance()
|
||||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||||
|
|
||||||
|
# Setup prompt messages
|
||||||
prompt_messages = [
|
prompt_messages = [
|
||||||
SystemPromptMessage(content="You are a helpful assistant."),
|
SystemPromptMessage(content="You are a helpful assistant."),
|
||||||
UserPromptMessage(content="Generate a response according to the schema."),
|
UserPromptMessage(content="Generate a response according to the schema."),
|
||||||
]
|
]
|
||||||
|
|
||||||
if case["should_raise"]:
|
if case["should_raise"]:
|
||||||
expected_error = case.get("expected_error", OutputParserError)
|
# Test error cases
|
||||||
with pytest.raises(expected_error): # noqa: PT012
|
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||||
if stream:
|
if case["stream"]:
|
||||||
result_generator = invoke_llm_with_structured_output(
|
result_generator = invoke_llm_with_structured_output(
|
||||||
provider=provider,
|
provider=case["provider"],
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=json_schema,
|
json_schema=case["json_schema"],
|
||||||
)
|
)
|
||||||
|
# Consume the generator to trigger the error
|
||||||
list(result_generator)
|
list(result_generator)
|
||||||
else:
|
else:
|
||||||
invoke_llm_with_structured_output(
|
invoke_llm_with_structured_output(
|
||||||
provider=provider,
|
provider=case["provider"],
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=json_schema,
|
json_schema=case["json_schema"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Test successful cases
|
||||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||||
|
# Configure json_repair mock for cases that need it
|
||||||
if case["name"] == "json_repair_scenario":
|
if case["name"] == "json_repair_scenario":
|
||||||
mock_json_repair.return_value = {"name": "test"}
|
mock_json_repair.return_value = {"name": "test"}
|
||||||
|
|
||||||
result = invoke_llm_with_structured_output(
|
result = invoke_llm_with_structured_output(
|
||||||
provider=provider,
|
provider=case["provider"],
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=json_schema,
|
json_schema=case["json_schema"],
|
||||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||||
user="test_user",
|
user="test_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_result_type = case["expected_result_type"]
|
if case["expected_result_type"] == "generator":
|
||||||
assert expected_result_type is not None
|
# Test streaming results
|
||||||
assert isinstance(result, expected_result_type)
|
assert hasattr(result, "__iter__")
|
||||||
assert result.model == model_name
|
chunks = list(result)
|
||||||
assert result.structured_output is not None
|
assert len(chunks) > 0
|
||||||
assert isinstance(result.structured_output, dict)
|
|
||||||
|
|
||||||
|
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||||
|
for chunk in chunks[:-1]: # All except last
|
||||||
|
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||||
|
assert chunk.model == case["model_name"]
|
||||||
|
|
||||||
|
# Last chunk should have structured output
|
||||||
|
last_chunk = chunks[-1]
|
||||||
|
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||||
|
assert last_chunk.structured_output is not None
|
||||||
|
assert isinstance(last_chunk.structured_output, dict)
|
||||||
|
else:
|
||||||
|
# Test non-streaming results
|
||||||
|
assert isinstance(result, case["expected_result_type"])
|
||||||
|
assert result.model == case["model_name"]
|
||||||
|
assert result.structured_output is not None
|
||||||
|
assert isinstance(result.structured_output, dict)
|
||||||
|
|
||||||
|
# Verify model_instance.invoke_llm was called with correct parameters
|
||||||
model_instance.invoke_llm.assert_called_once()
|
model_instance.invoke_llm.assert_called_once()
|
||||||
call_args = model_instance.invoke_llm.call_args
|
call_args = model_instance.invoke_llm.call_args
|
||||||
|
|
||||||
assert call_args.kwargs["stream"] == stream
|
assert call_args.kwargs["stream"] == case["stream"]
|
||||||
assert call_args.kwargs["user"] == "test_user"
|
assert call_args.kwargs["user"] == "test_user"
|
||||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||||
@ -306,32 +376,45 @@ def test_structured_output_parser():
|
|||||||
def test_parse_structured_output_edge_cases():
|
def test_parse_structured_output_edge_cases():
|
||||||
"""Test edge cases for structured output parsing"""
|
"""Test edge cases for structured output parsing"""
|
||||||
|
|
||||||
provider = "deepseek"
|
# Test case with list that contains dict (reasoning model scenario)
|
||||||
model_name = "deepseek-r1"
|
testcase_list_with_dict = {
|
||||||
support_structure_output = False
|
"name": "list_with_dict_parsing",
|
||||||
json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
|
"provider": "deepseek",
|
||||||
expected_llm_response = LLMResult(
|
"model_name": "deepseek-r1",
|
||||||
model="deepseek-r1",
|
"support_structure_output": False,
|
||||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
"stream": False,
|
||||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||||
|
"expected_llm_response": LLMResult(
|
||||||
|
model="deepseek-r1",
|
||||||
|
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||||
|
),
|
||||||
|
"expected_result_type": LLMResultWithStructuredOutput,
|
||||||
|
"should_raise": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup for list parsing test
|
||||||
|
model_schema = get_model_entity(
|
||||||
|
testcase_list_with_dict["provider"],
|
||||||
|
testcase_list_with_dict["model_name"],
|
||||||
|
testcase_list_with_dict["support_structure_output"],
|
||||||
)
|
)
|
||||||
|
|
||||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
|
||||||
|
|
||||||
model_instance = get_model_instance()
|
model_instance = get_model_instance()
|
||||||
model_instance.invoke_llm.return_value = expected_llm_response
|
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||||
|
|
||||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||||
|
|
||||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||||
|
# Mock json_repair to return a list with dict
|
||||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||||
|
|
||||||
result = invoke_llm_with_structured_output(
|
result = invoke_llm_with_structured_output(
|
||||||
provider=provider,
|
provider=testcase_list_with_dict["provider"],
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=json_schema,
|
json_schema=testcase_list_with_dict["json_schema"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||||
@ -341,16 +424,18 @@ def test_parse_structured_output_edge_cases():
|
|||||||
def test_model_specific_schema_preparation():
|
def test_model_specific_schema_preparation():
|
||||||
"""Test schema preparation for different model types"""
|
"""Test schema preparation for different model types"""
|
||||||
|
|
||||||
provider = "google"
|
# Test Gemini model
|
||||||
model_name = "gemini-pro"
|
gemini_case = {
|
||||||
support_structure_output = True
|
"provider": "google",
|
||||||
json_schema: SchemaData = {
|
"model_name": "gemini-pro",
|
||||||
"type": "object",
|
"support_structure_output": True,
|
||||||
"properties": {"result": {"type": "boolean"}},
|
"stream": False,
|
||||||
"additionalProperties": False,
|
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
model_schema = get_model_entity(
|
||||||
|
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||||
|
)
|
||||||
|
|
||||||
model_instance = get_model_instance()
|
model_instance = get_model_instance()
|
||||||
model_instance.invoke_llm.return_value = LLMResult(
|
model_instance.invoke_llm.return_value = LLMResult(
|
||||||
@ -362,11 +447,11 @@ def test_model_specific_schema_preparation():
|
|||||||
prompt_messages = [UserPromptMessage(content="Test")]
|
prompt_messages = [UserPromptMessage(content="Test")]
|
||||||
|
|
||||||
result = invoke_llm_with_structured_output(
|
result = invoke_llm_with_structured_output(
|
||||||
provider=provider,
|
provider=gemini_case["provider"],
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
json_schema=json_schema,
|
json_schema=gemini_case["json_schema"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||||
@ -408,26 +493,40 @@ def test_structured_output_with_pydantic_model_non_streaming():
|
|||||||
assert result.name == "test"
|
assert result.name == "test"
|
||||||
|
|
||||||
|
|
||||||
def test_structured_output_with_pydantic_model_list_content():
|
def test_structured_output_with_pydantic_model_streaming():
|
||||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||||
model_instance = get_model_instance()
|
model_instance = get_model_instance()
|
||||||
model_instance.invoke_llm.return_value = LLMResult(
|
|
||||||
model="gpt-4o",
|
def mock_streaming_response():
|
||||||
message=AssistantPromptMessage(
|
yield LLMResultChunk(
|
||||||
content=[
|
model="gpt-4o",
|
||||||
TextPromptMessageContent(data='{"name":'),
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
TextPromptMessageContent(data=' "test"}'),
|
system_fingerprint="test",
|
||||||
]
|
delta=LLMResultChunkDelta(
|
||||||
),
|
index=0,
|
||||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
message=AssistantPromptMessage(content='{"name":'),
|
||||||
)
|
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield LLMResultChunk(
|
||||||
|
model="gpt-4o",
|
||||||
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
|
system_fingerprint="test",
|
||||||
|
delta=LLMResultChunkDelta(
|
||||||
|
index=0,
|
||||||
|
message=AssistantPromptMessage(content=' "test"}'),
|
||||||
|
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
model_instance.invoke_llm.return_value = mock_streaming_response()
|
||||||
|
|
||||||
result = invoke_llm_with_pydantic_model(
|
result = invoke_llm_with_pydantic_model(
|
||||||
provider="openai",
|
provider="openai",
|
||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
|
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
|
||||||
output_model=ExampleOutput,
|
output_model=ExampleOutput
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, ExampleOutput)
|
assert isinstance(result, ExampleOutput)
|
||||||
@ -449,51 +548,51 @@ def test_structured_output_with_pydantic_model_validation_error():
|
|||||||
model_schema=model_schema,
|
model_schema=model_schema,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=[UserPromptMessage(content="test")],
|
prompt_messages=[UserPromptMessage(content="test")],
|
||||||
output_model=ExampleOutput,
|
output_model=ExampleOutput
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestGetDefaultValueForType:
|
class TestGetDefaultValueForType:
|
||||||
"""Test cases for get_default_value_for_type function"""
|
"""Test cases for _get_default_value_for_type function"""
|
||||||
|
|
||||||
def test_string_type(self):
|
def test_string_type(self):
|
||||||
assert get_default_value_for_type("string") == ""
|
assert _get_default_value_for_type("string") == ""
|
||||||
|
|
||||||
def test_object_type(self):
|
def test_object_type(self):
|
||||||
assert get_default_value_for_type("object") == {}
|
assert _get_default_value_for_type("object") == {}
|
||||||
|
|
||||||
def test_array_type(self):
|
def test_array_type(self):
|
||||||
assert get_default_value_for_type("array") == []
|
assert _get_default_value_for_type("array") == []
|
||||||
|
|
||||||
def test_number_type(self):
|
def test_number_type(self):
|
||||||
assert get_default_value_for_type("number") == 0
|
assert _get_default_value_for_type("number") == 0
|
||||||
|
|
||||||
def test_integer_type(self):
|
def test_integer_type(self):
|
||||||
assert get_default_value_for_type("integer") == 0
|
assert _get_default_value_for_type("integer") == 0
|
||||||
|
|
||||||
def test_boolean_type(self):
|
def test_boolean_type(self):
|
||||||
assert get_default_value_for_type("boolean") is False
|
assert _get_default_value_for_type("boolean") is False
|
||||||
|
|
||||||
def test_null_type(self):
|
def test_null_type(self):
|
||||||
assert get_default_value_for_type("null") is None
|
assert _get_default_value_for_type("null") is None
|
||||||
|
|
||||||
def test_none_type(self):
|
def test_none_type(self):
|
||||||
assert get_default_value_for_type(None) is None
|
assert _get_default_value_for_type(None) is None
|
||||||
|
|
||||||
def test_unknown_type(self):
|
def test_unknown_type(self):
|
||||||
assert get_default_value_for_type("unknown") is None
|
assert _get_default_value_for_type("unknown") is None
|
||||||
|
|
||||||
def test_union_type_string_null(self):
|
def test_union_type_string_null(self):
|
||||||
# ["string", "null"] should return "" (first non-null type)
|
# ["string", "null"] should return "" (first non-null type)
|
||||||
assert get_default_value_for_type(["string", "null"]) == ""
|
assert _get_default_value_for_type(["string", "null"]) == ""
|
||||||
|
|
||||||
def test_union_type_null_first(self):
|
def test_union_type_null_first(self):
|
||||||
# ["null", "integer"] should return 0 (first non-null type)
|
# ["null", "integer"] should return 0 (first non-null type)
|
||||||
assert get_default_value_for_type(["null", "integer"]) == 0
|
assert _get_default_value_for_type(["null", "integer"]) == 0
|
||||||
|
|
||||||
def test_union_type_only_null(self):
|
def test_union_type_only_null(self):
|
||||||
# ["null"] should return None
|
# ["null"] should return None
|
||||||
assert get_default_value_for_type(["null"]) is None
|
assert _get_default_value_for_type(["null"]) is None
|
||||||
|
|
||||||
|
|
||||||
class TestFillDefaultsFromSchema:
|
class TestFillDefaultsFromSchema:
|
||||||
@ -501,7 +600,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_simple_required_fields(self):
|
def test_simple_required_fields(self):
|
||||||
"""Test filling simple required fields"""
|
"""Test filling simple required fields"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {"type": "string"},
|
"name": {"type": "string"},
|
||||||
@ -510,7 +609,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["name", "age"],
|
"required": ["name", "age"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {"name": "Alice"}
|
output = {"name": "Alice"}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -520,7 +619,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_non_required_fields_not_filled(self):
|
def test_non_required_fields_not_filled(self):
|
||||||
"""Test that non-required fields are not filled"""
|
"""Test that non-required fields are not filled"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"required_field": {"type": "string"},
|
"required_field": {"type": "string"},
|
||||||
@ -528,7 +627,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["required_field"],
|
"required": ["required_field"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {}
|
output = {}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -537,7 +636,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_nested_object_required_fields(self):
|
def test_nested_object_required_fields(self):
|
||||||
"""Test filling nested object required fields"""
|
"""Test filling nested object required fields"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"user": {
|
"user": {
|
||||||
@ -560,7 +659,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["user"],
|
"required": ["user"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {
|
output = {
|
||||||
"user": {
|
"user": {
|
||||||
"name": "Alice",
|
"name": "Alice",
|
||||||
"address": {
|
"address": {
|
||||||
@ -585,7 +684,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_missing_nested_object_created(self):
|
def test_missing_nested_object_created(self):
|
||||||
"""Test that missing required nested objects are created"""
|
"""Test that missing required nested objects are created"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -599,7 +698,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["metadata"],
|
"required": ["metadata"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {}
|
output = {}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -611,7 +710,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_all_types_default_values(self):
|
def test_all_types_default_values(self):
|
||||||
"""Test default values for all types"""
|
"""Test default values for all types"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"str_field": {"type": "string"},
|
"str_field": {"type": "string"},
|
||||||
@ -623,7 +722,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
|
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {}
|
output = {}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -638,7 +737,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_existing_values_preserved(self):
|
def test_existing_values_preserved(self):
|
||||||
"""Test that existing values are not overwritten"""
|
"""Test that existing values are not overwritten"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"name": {"type": "string"},
|
"name": {"type": "string"},
|
||||||
@ -646,7 +745,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["name", "count"],
|
"required": ["name", "count"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {"name": "Bob", "count": 42}
|
output = {"name": "Bob", "count": 42}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -654,7 +753,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_complex_nested_structure(self):
|
def test_complex_nested_structure(self):
|
||||||
"""Test complex nested structure with multiple levels"""
|
"""Test complex nested structure with multiple levels"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"user": {
|
"user": {
|
||||||
@ -690,7 +789,7 @@ class TestFillDefaultsFromSchema:
|
|||||||
},
|
},
|
||||||
"required": ["user", "tags", "metadata", "is_active"],
|
"required": ["user", "tags", "metadata", "is_active"],
|
||||||
}
|
}
|
||||||
output: SchemaData = {
|
output = {
|
||||||
"user": {
|
"user": {
|
||||||
"name": "Alice",
|
"name": "Alice",
|
||||||
"age": 25,
|
"age": 25,
|
||||||
@ -730,8 +829,8 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_empty_schema(self):
|
def test_empty_schema(self):
|
||||||
"""Test with empty schema"""
|
"""Test with empty schema"""
|
||||||
schema: SchemaData = {}
|
schema = {}
|
||||||
output: SchemaData = {"any": "value"}
|
output = {"any": "value"}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
@ -739,14 +838,14 @@ class TestFillDefaultsFromSchema:
|
|||||||
|
|
||||||
def test_schema_without_required(self):
|
def test_schema_without_required(self):
|
||||||
"""Test schema without required field"""
|
"""Test schema without required field"""
|
||||||
schema: SchemaData = {
|
schema = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"optional1": {"type": "string"},
|
"optional1": {"type": "string"},
|
||||||
"optional2": {"type": "integer"},
|
"optional2": {"type": "integer"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
output: SchemaData = {}
|
output = {}
|
||||||
|
|
||||||
result = fill_defaults_from_schema(output, schema)
|
result = fill_defaults_from_schema(output, schema)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user