mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/agent-node-v2' into deploy/dev
commit 960b0707c8
@@ -202,6 +202,7 @@ message_detail_model = console_ns.model(
        "status": fields.String,
        "error": fields.String,
        "parent_message_id": fields.String,
        "generation_detail": fields.Raw,
    },
)
@@ -74,6 +74,7 @@ def build_message_model(api_or_ns: Namespace):
        "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
        "status": fields.String,
        "error": fields.String,
        "generation_detail": fields.Raw,
    }
    return api_or_ns.model("Message", message_fields)
@@ -85,6 +85,7 @@ class MessageListApi(WebApiResource):
        "metadata": fields.Raw(attribute="message_metadata_dict"),
        "status": fields.String,
        "error": fields.String,
        "generation_detail": fields.Raw,
    }

    message_infinite_scroll_pagination_fields = {
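The three hunks above each add a "generation_detail": fields.Raw entry to a message marshalling model. As a minimal, self-contained sketch (illustrative payload, not taken from this commit), flask-restx's fields.Raw passes an arbitrary JSON-serializable value through marshalling unchanged:

from flask_restx import fields, marshal

message_fields = {
    "id": fields.String,
    "generation_detail": fields.Raw,  # arbitrary nested data, no schema enforced
}

data = {"id": "msg-1", "generation_detail": {"rounds": 2, "total_tokens": 512}}
print(marshal(data, message_fields))
# {'id': 'msg-1', 'generation_detail': {'rounds': 2, 'total_tokens': 512}}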
@@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class AgentAppRunner(BaseAgentRunner):
    def _create_tool_invoke_hook(self, message: Message):
        """
        Create a tool invoke hook that uses ToolEngine.agent_invoke.
        This hook handles file creation and returns proper meta information.
        """
        # Get trace manager from app generate entity
        trace_manager = self.application_generate_entity.trace_manager

        def tool_invoke_hook(
            tool: Tool, tool_args: dict[str, Any], tool_name: str
        ) -> tuple[str, list[str], ToolInvokeMeta]:
            """Hook that uses agent_invoke for proper file and meta handling."""
            tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                tool=tool,
                tool_parameters=tool_args,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                message=message,
                invoke_from=self.application_generate_entity.invoke_from,
                agent_tool_callback=self.agent_callback,
                trace_manager=trace_manager,
                app_id=self.application_generate_entity.app_config.app_id,
                message_id=message.id,
                conversation_id=self.conversation.id,
            )

            # Publish files and track IDs
            for message_file_id in message_files:
                self.queue_manager.publish(
                    QueueMessageFileEvent(message_file_id=message_file_id),
                    PublishFrom.APPLICATION_MANAGER,
                )
                self._current_message_file_ids.append(message_file_id)

            return tool_invoke_response, message_files, tool_invoke_meta

        return tool_invoke_hook

    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run Agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, _ = self._init_prompt_tools()

        assert app_config.agent

        # Create tool invoke hook for agent_invoke
        tool_invoke_hook = self._create_tool_invoke_hook(message)

        # Get instruction for ReAct strategy
        instruction = self.app_config.prompt_template.simple_prompt_template or ""

        # Use factory to create appropriate strategy
        strategy = StrategyFactory.create_strategy(
            model_features=self.model_features,
            model_instance=self.model_instance,
            tools=list(tool_instances.values()),
            files=list(self.files),
            max_iterations=app_config.agent.max_iteration,
            context=self.build_execution_context(),
            agent_strategy=self.config.strategy,
            tool_invoke_hook=tool_invoke_hook,
            instruction=instruction,
        )

        # Initialize state variables
        current_agent_thought_id = None
        has_published_thought = False
        current_tool_name: str | None = None
        self._current_message_file_ids: list[str] = []

        # organize prompt messages
        prompt_messages = self._organize_prompt_messages()

        # Run strategy
        generator = strategy.run(
            prompt_messages=prompt_messages,
            model_parameters=app_generate_entity.model_conf.parameters,
            stop=app_generate_entity.model_conf.stop,
            stream=True,
        )

        # Consume generator and collect result
        result: AgentResult | None = None
        try:
            while True:
                try:
                    output = next(generator)
                except StopIteration as e:
                    # Generator finished, get the return value
                    result = e.value
                    break

                if isinstance(output, LLMResultChunk):
                    # Handle LLM chunk
                    if current_agent_thought_id and not has_published_thought:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                            PublishFrom.APPLICATION_MANAGER,
                        )
                        has_published_thought = True

                    yield output

                elif isinstance(output, AgentLog):
                    # Handle Agent Log using log_type for type-safe dispatch
                    if output.status == AgentLog.LogStatus.START:
                        if output.log_type == AgentLog.LogType.ROUND:
                            # Start of a new round
                            message_file_ids: list[str] = []
                            current_agent_thought_id = self.create_agent_thought(
                                message_id=message.id,
                                message="",
                                tool_name="",
                                tool_input="",
                                messages_ids=message_file_ids,
                            )
                            has_published_thought = False

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call start - extract data from structured fields
                            current_tool_name = output.data.get("tool_name", "")
                            tool_input = output.data.get("tool_args", {})

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=current_tool_name,
                                tool_input=tool_input,
                                thought=None,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                    elif output.status == AgentLog.LogStatus.SUCCESS:
                        if output.log_type == AgentLog.LogType.THOUGHT:
                            if current_agent_thought_id is None:
                                continue

                            thought_text = output.data.get("thought")
                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=thought_text,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call finished
                            tool_output = output.data.get("output")
                            # Get meta from strategy output (now properly populated)
                            tool_meta = output.data.get("meta")

                            # Wrap tool_meta with tool_name as key (required by agent_service)
                            if tool_meta and current_tool_name:
                                tool_meta = {current_tool_name: tool_meta}

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=None,
                                observation=tool_output,
                                tool_invoke_meta=tool_meta,
                                answer=None,
                                messages_ids=self._current_message_file_ids,
                            )
                            # Clear message file ids after saving
                            self._current_message_file_ids = []
                            current_tool_name = None

                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.ROUND:
                            if current_agent_thought_id is None:
                                continue

                            # Round finished - save LLM usage and answer
                            llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
                            llm_result = output.data.get("llm_result")
                            final_answer = output.data.get("final_answer")

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=llm_result,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=final_answer,
                                messages_ids=[],
                                llm_usage=llm_usage,
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

        except Exception:
            # Re-raise any other exceptions
            raise

        # Process final result
        if isinstance(result, AgentResult):
            final_answer = result.text
            usage = result.usage or LLMUsage.empty_usage()

            # Publish end event
            self.queue_manager.publish(
                QueueMessageEndEvent(
                    llm_result=LLMResult(
                        model=self.model_instance.model,
                        prompt_messages=prompt_messages,
                        message=AssistantPromptMessage(content=final_answer),
                        usage=usage,
                        system_fingerprint="",
                    )
                ),
                PublishFrom.APPLICATION_MANAGER,
            )

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_template:
            return prompt_messages or []

        prompt_messages = prompt_messages or []

        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            prompt_messages[0] = SystemPromptMessage(content=prompt_template)
            return prompt_messages

        if not prompt_messages:
            return [SystemPromptMessage(content=prompt_template)]

        prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
        return prompt_messages

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As for now, gpt supports both fc and vision at the first iteration.
        We need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        # For ReAct strategy, use the agent prompt template
        if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
            prompt_template = self.config.prompt.first_prompt
        else:
            prompt_template = self.app_config.prompt_template.simple_prompt_template or ""

        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
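The run() method above drains the strategy generator by hand and recovers its return value from StopIteration.value. A standalone sketch of that Python pattern, with illustrative names only:

from collections.abc import Generator


def strategy_like() -> Generator[str, None, dict]:
    # Yield streamed chunks, then return a final result object.
    yield "chunk-1"
    yield "chunk-2"
    return {"text": "final answer"}


gen = strategy_like()
result = None
while True:
    try:
        print(next(gen))       # streamed output, forwarded to the caller
    except StopIteration as e:
        result = e.value       # the generator's return value
        break

print(result)  # {'text': 'final answer'}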
@@ -5,7 +5,7 @@ from typing import Union, cast

from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner):
        features = model_schema.features if model_schema and model_schema.features else []
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
        self.model_features = features
        self.query: str | None = ""
        self._current_thoughts: list[PromptMessage] = []

    def build_execution_context(self) -> ExecutionContext:
        """Build execution context."""
        return ExecutionContext(
            user_id=self.user_id,
            app_id=self.app_config.app_id,
            conversation_id=self.conversation.id,
            message_id=self.message.id,
            tenant_id=self.tenant_id,
        )

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
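build_execution_context() above returns the ExecutionContext model added later in this diff. A usage sketch with placeholder IDs, relying only on the fields and helpers defined in that hunk:

from core.agent.entities import ExecutionContext

ctx = ExecutionContext(
    user_id="user-123",         # placeholder values for illustration
    app_id="app-456",
    conversation_id="conv-789",
    message_id="msg-000",
    tenant_id="tenant-abc",
)

# with_updates returns a new context instead of mutating the original.
retry_ctx = ctx.with_updates(message_id="msg-001")
print(retry_ctx.to_dict()["message_id"])  # msg-001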
@@ -1,431 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    _agent_scratchpad: list[AgentScratchpadUnit]
    _instruction: str
    _query: str
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(
        self,
        message: Message,
        query: str,
        inputs: Mapping[str, str],
    ) -> Generator:
        """
        Run Cot agent application
        """

        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
        self._init_react_state(query)

        trace_manager = app_generate_entity.trace_manager

        # check model mode
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config
        assert app_config.agent

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template or ""
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
        self._prompt_messages_tools = prompt_messages_tools

        function_call_state = True
        llm_usage: dict[str, LLMUsage | None] = {"usage": None}
        final_answer = ""
        prompt_messages: list = []  # Initialize prompt_messages
        agent_thought_id = ""  # Initialize agent_thought_id

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.total_tokens += usage.total_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            usage_dict: dict[str, LLMUsage | None] = {}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
                    action = chunk
                    # detect action
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += chunk
                    assert scratchpad.thought is not None
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )

            assert scratchpad.thought is not None
            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # get llm usage
            if "usage" in usage_dict:
                if usage_dict["usage"] is not None:
                    increase_usage(llm_usage, usage_dict["usage"])
            else:
                usage_dict["usage"] = LLMUsage.empty_usage()

            self.save_agent_thought(
                agent_thought_id=agent_thought_id,
                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                tool_invoke_meta={},
                thought=scratchpad.thought or "",
                observation="",
                answer=scratchpad.agent_response or "",
                messages_ids=[],
                llm_usage=usage_dict["usage"],
            )

            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except TypeError:
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    function_call_state = True
                    # action is tool call, invoke tool
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids,
                        trace_manager=trace_manager,
                    )
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    self.save_agent_thought(
                        agent_thought_id=agent_thought_id,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought or "",
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                    )

            # update prompt tool message
            for prompt_tool in self._prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought_id=agent_thought_id,
            tool_name="",
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[],
        )
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: Mapping[str, Tool],
        message_file_ids: list[str],
        trace_manager: TraceQueueManager | None = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
        :param tool_instances: tool instances
        :param message_file_ids: message file ids
        :param trace_manager: trace manager
        :return: observation, meta
        """
        # action is tool call, invoke tool
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            answer = f"there is not a tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback,
            trace_manager=trace_manager,
        )

        # publish files
        for message_file_id in message_files:
            # publish message file
            self.queue_manager.publish(
                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
            message_file_ids.append(message_file_id)

        return tool_invoke_response, tool_invoke_meta

    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
        """
        fill in inputs from external data tools
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
            except Exception:
                continue

        return instruction

    def _init_react_state(self, query):
        """
        init agent scratchpad
        """
        self._query = query
        self._agent_scratchpad = []
        self._historic_prompt_messages = self._organize_historic_prompt_messages()

    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ""
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
            else:
                message += f"Thought: {scratchpad.thought}\n\n"
                if scratchpad.action_str:
                    message += f"Action: {scratchpad.action_str}\n\n"
                if scratchpad.observation:
                    message += f"Observation: {scratchpad.observation}\n\n"

        return message

    def _organize_historic_prompt_messages(
        self, current_session_messages: list[PromptMessage] | None = None
    ) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpads: list[AgentScratchpadUnit] = []
        current_scratchpad: AgentScratchpadUnit | None = None

        for message in self.history_prompt_messages:
            if isinstance(message, AssistantPromptMessage):
                if not current_scratchpad:
                    assert isinstance(message.content, str)
                    current_scratchpad = AgentScratchpadUnit(
                        agent_response=message.content,
                        thought=message.content or "I am thinking about how to help you",
                        action_str="",
                        action=None,
                        observation=None,
                    )
                    scratchpads.append(current_scratchpad)
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments),
                        )
                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                    except Exception:
                        logger.exception("Failed to parse tool call from assistant message")
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    assert isinstance(message.content, str)
                    current_scratchpad.observation = message.content
                else:
                    raise NotImplementedError("expected str type")
            elif isinstance(message, UserPromptMessage):
                if scratchpads:
                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                    scratchpads = []
                    current_scratchpad = None

                result.append(message)

        if scratchpads:
            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

        historic_prompts = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=result,
            memory=self.memory,
        ).get_prompt()
        return historic_prompts
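The removed CotAgentRunner serializes each ReAct round into a Thought / Action / Observation transcript via _format_assistant_message. An illustrative round with placeholder values:

thought = "I should look up the weather first."
action_str = '{"action": "weather_tool", "action_input": {"city": "Berlin"}}'
observation = "12°C, light rain"

# Mirrors the layout produced by _format_assistant_message for a non-final step.
round_text = f"Thought: {thought}\n\nAction: {action_str}\n\nObservation: {observation}\n\n"
print(round_text)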
@@ -1,118 +0,0 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder


class CotChatAgentRunner(CotAgentRunner):
    def _organize_system_prompt(self) -> SystemPromptMessage:
        """
        Organize system prompt
        """
        assert self.app_config.agent
        assert self.app_config.agent.prompt

        prompt_entity = self.app_config.agent.prompt
        if not prompt_entity:
            raise ValueError("Agent prompt configuration is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return SystemPromptMessage(content=system_prompt)

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize
        """
        # organize system prompt
        system_message = self._organize_system_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        if not agent_scratchpad:
            assistant_messages = []
        else:
            assistant_message = AssistantPromptMessage(content="")
            assistant_message.content = ""  # FIXME: type check tell mypy that assistant_message.content is str
            for unit in agent_scratchpad:
                if unit.is_final():
                    assert isinstance(assistant_message.content, str)
                    assistant_message.content += f"Final Answer: {unit.agent_response}"
                else:
                    assert isinstance(assistant_message.content, str)
                    assistant_message.content += f"Thought: {unit.thought}\n\n"
                    if unit.action_str:
                        assistant_message.content += f"Action: {unit.action_str}\n\n"
                    if unit.observation:
                        assistant_message.content += f"Observation: {unit.observation}\n\n"

            assistant_messages = [assistant_message]

        # query messages
        query_messages = self._organize_user_query(self._query, [])

        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages(
                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
            )
            messages = [
                system_message,
                *historic_messages,
                *query_messages,
                *assistant_messages,
                UserPromptMessage(content="continue"),
            ]
        else:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
            messages = [system_message, *historic_messages, *query_messages]

        # join all messages
        return messages
@@ -1,87 +0,0 @@
import json

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder


class CotCompletionAgentRunner(CotAgentRunner):
    def _organize_instruction_prompt(self) -> str:
        """
        Organize instruction prompt
        """
        if self.app_config.agent is None:
            raise ValueError("Agent configuration is not set")
        prompt_entity = self.app_config.agent.prompt
        if prompt_entity is None:
            raise ValueError("prompt entity is not set")
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return system_prompt

    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
        """
        Organize historic prompt
        """
        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
        historic_prompt = ""

        for message in historic_prompt_messages:
            if isinstance(message, UserPromptMessage):
                historic_prompt += f"Question: {message.content}\n\n"
            elif isinstance(message, AssistantPromptMessage):
                if isinstance(message.content, str):
                    historic_prompt += message.content + "\n\n"
                elif isinstance(message.content, list):
                    for content in message.content:
                        if not isinstance(content, TextPromptMessageContent):
                            continue
                        historic_prompt += content.data

        return historic_prompt

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        # organize system prompt
        system_prompt = self._organize_instruction_prompt()

        # organize historic prompt messages
        historic_prompt = self._organize_historic_prompt()

        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        assistant_prompt = ""
        for unit in agent_scratchpad or []:
            if unit.is_final():
                assistant_prompt += f"Final Answer: {unit.agent_response}"
            else:
                assistant_prompt += f"Thought: {unit.thought}\n\n"
                if unit.action_str:
                    assistant_prompt += f"Action: {unit.action_str}\n\n"
                if unit.observation:
                    assistant_prompt += f"Observation: {unit.observation}\n\n"

        # query messages
        query_prompt = f"Question: {self._query}"

        # join all messages
        prompt = (
            system_prompt.replace("{{historic_messages}}", historic_prompt)
            .replace("{{agent_scratchpad}}", assistant_prompt)
            .replace("{{query}}", query_prompt)
        )

        return [UserPromptMessage(content=prompt)]
@@ -1,3 +1,5 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@@ -92,3 +94,94 @@ class AgentInvokeMessage(ToolInvokeMessage):
    """

    pass


class ExecutionContext(BaseModel):
    """Execution context containing trace and audit information.

    This context carries all the IDs and metadata that are not part of
    the core business logic but needed for tracing, auditing, and
    correlation purposes.
    """

    user_id: str | None = None
    app_id: str | None = None
    conversation_id: str | None = None
    message_id: str | None = None
    tenant_id: str | None = None

    @classmethod
    def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
        """Create a minimal context with only essential fields."""
        return cls(user_id=user_id)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for passing to legacy code."""
        return {
            "user_id": self.user_id,
            "app_id": self.app_id,
            "conversation_id": self.conversation_id,
            "message_id": self.message_id,
            "tenant_id": self.tenant_id,
        }

    def with_updates(self, **kwargs) -> "ExecutionContext":
        """Create a new context with updated fields."""
        data = self.to_dict()
        data.update(kwargs)

        return ExecutionContext(
            user_id=data.get("user_id"),
            app_id=data.get("app_id"),
            conversation_id=data.get("conversation_id"),
            message_id=data.get("message_id"),
            tenant_id=data.get("tenant_id"),
        )


class AgentLog(BaseModel):
    """
    Agent Log.
    """

    class LogType(StrEnum):
        """Type of agent log entry."""

        ROUND = "round"  # A complete iteration round
        THOUGHT = "thought"  # LLM thinking/reasoning
        TOOL_CALL = "tool_call"  # Tool invocation

    class LogMetadata(StrEnum):
        STARTED_AT = "started_at"
        FINISHED_AT = "finished_at"
        ELAPSED_TIME = "elapsed_time"
        TOTAL_PRICE = "total_price"
        TOTAL_TOKENS = "total_tokens"
        PROVIDER = "provider"
        CURRENCY = "currency"
        LLM_USAGE = "llm_usage"

    class LogStatus(StrEnum):
        START = "start"
        ERROR = "error"
        SUCCESS = "success"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
    label: str = Field(..., description="The label of the log")
    log_type: LogType = Field(..., description="The type of the log")
    parent_id: str | None = Field(default=None, description="Leave empty for root log")
    error: str | None = Field(default=None, description="The error message")
    status: LogStatus = Field(..., description="The status of the log")
    data: Mapping[str, Any] = Field(..., description="Detailed log data")
    metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")


class AgentResult(BaseModel):
    """
    Agent execution result.
    """

    text: str = Field(default="", description="The generated text")
    files: list[Any] = Field(default_factory=list, description="Files produced during execution")
    usage: Any | None = Field(default=None, description="LLM usage statistics")
    finish_reason: str | None = Field(default=None, description="Reason for completion")
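A construction sketch for the AgentLog and AgentResult models added above, using placeholder values only:

from core.agent.entities import AgentLog, AgentResult

log = AgentLog(
    label="tool_call: weather_tool",  # placeholder label
    log_type=AgentLog.LogType.TOOL_CALL,
    status=AgentLog.LogStatus.SUCCESS,
    data={"tool_name": "weather_tool", "output": "12°C, light rain"},
    metadata={AgentLog.LogMetadata.ELAPSED_TIME: 0.42},
)

result = AgentResult(text="It is 12°C and raining in Berlin.", finish_reason="stop")
print(log.log_type, result.text)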
@ -1,465 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
final_answer = ""
|
||||
prompt_messages: list = [] # Initialize prompt_messages
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.total_tokens += usage.total_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
llm_usage.completion_price += usage.completion_price
|
||||
llm_usage.total_price += usage.total_price
|
||||
|
||||
model_instance = self.model_instance
|
||||
|
||||
while function_call_state and iteration_step <= max_iteration_steps:
|
||||
function_call_state = False
|
||||
|
||||
if iteration_step == max_iteration_steps:
|
||||
# the last iteration, remove all tools
|
||||
prompt_messages_tools = []
|
||||
|
||||
message_file_ids: list[str] = []
|
||||
agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
|
||||
# invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
tools=prompt_messages_tools,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=self.stream_tool_call,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ""
|
||||
|
||||
# save tool call names and inputs
|
||||
tool_call_names = ""
|
||||
tool_call_inputs = ""
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
is_first_chunk = False
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except TypeError:
|
||||
# fallback: force ASCII to handle non-serializable objects
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
for content in chunk.delta.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += str(chunk.delta.message.content)
|
||||
|
||||
if chunk.delta.usage:
|
||||
increase_usage(llm_usage, chunk.delta.usage)
|
||||
current_llm_usage = chunk.delta.usage
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except TypeError:
|
||||
# fallback: force ASCII to handle non-serializable objects
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if result.usage:
|
||||
increase_usage(llm_usage, result.usage)
|
||||
current_llm_usage = result.usage
|
||||
|
||||
if result.message and result.message.content:
|
||||
if isinstance(result.message.content, list):
|
||||
for content in result.message.content:
|
||||
response += content.data
|
||||
else:
|
||||
response += str(result.message.content)
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ""
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
system_fingerprint=result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
# save thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
tool_invoke_meta=None,
|
||||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
llm_usage=current_llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
final_answer += response + "\n"
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
|
||||
tool_instance = tool_instances.get(tool_call_name)
|
||||
if not tool_instance:
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||
}
|
||||
else:
|
||||
# invoke tool
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=self.message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# publish files
|
||||
for message_file_id in message_files:
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": tool_invoke_response,
|
||||
"meta": tool_invoke_meta.to_dict(),
|
||||
}
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=str(tool_response["tool_response"]),
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
)
|
||||
|
||||
if len(tool_responses) > 0:
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=agent_thought_id,
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
thought="",
|
||||
tool_invoke_meta={
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
observation={
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
Check if there is any tool call in llm result chunk
|
||||
"""
|
||||
if llm_result_chunk.delta.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
||||
"""
|
||||
Check if there is any blocking tool call in llm result
|
||||
"""
|
||||
if llm_result.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
||||
"""
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
|
||||
"""
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_messages and prompt_template:
|
||||
return [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
]
|
||||
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages or []
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
For now, GPT supports both function calling and vision only on the first iteration.
|
||||
We therefore remove image content from the user prompt messages after the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Agent Patterns
|
||||
|
||||
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
|
||||
|
||||
## Overview
|
||||
|
||||
The module applies the strategy pattern to LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Dual strategies**
|
||||
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
|
||||
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
|
||||
- **Explicit or auto selection**
|
||||
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
|
||||
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
|
||||
- **Unified execution contract**
|
||||
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
|
||||
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
|
||||
- **Tool handling and hooks**
|
||||
- Tools are converted to `PromptMessageTool` objects before invocation.
|
||||
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
|
||||
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; files with `target == "self"` are loaded into the model context, while the rest are returned as outputs.
|
||||
- **File-aware arguments**
|
||||
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely (see the example after this list).
|
||||
- **ReAct prompt shaping**
|
||||
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
|
||||
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
|
||||
- **Observability and accounting**
|
||||
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
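
A minimal sketch of the file-placeholder format described above. The file IDs and argument names are illustrative only; the strategy swaps the placeholders for `File` objects drawn from the files it was constructed with before the tool runs.

```python
# Hypothetical tool-call arguments as a model might emit them (IDs are made up).
raw_args = {
    "image": "[File: 3f2a9c]",                # resolves to a single File object
    "references": "[Files: 3f2a9c, 77bd01]",  # resolves to a list of File objects
    "question": "Summarize the attachment",   # plain values pass through unchanged
}
# The strategy resolves these placeholders internally before invoking the tool,
# so the tool receives real File objects instead of the placeholder strings.
```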
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
agent/patterns/
|
||||
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
|
||||
├── function_call.py # Native function-calling loop with tool execution
|
||||
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
|
||||
└── strategy_factory.py # Strategy selection by model features or explicit override
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
- For auto-selection:
|
||||
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model parameters.
|
||||
- For explicit behavior:
|
||||
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
|
||||
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`. The driver sketch below shows one way to do this.
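
A minimal driver sketch for the flow above. The surrounding objects (`model_features`, `model_instance`, `context`, `tools`, `files`, `prompt_messages`, `model_parameters`) are assumed to be prepared by the caller and are not shown here.

```python
from core.agent.entities import AgentLog
from core.agent.patterns.strategy_factory import StrategyFactory

strategy = StrategyFactory.create_strategy(
    model_features=model_features,  # list[ModelFeature] taken from the model schema
    model_instance=model_instance,
    context=context,
    tools=tools,
    files=files,
    max_iterations=5,
)

generator = strategy.run(prompt_messages, model_parameters, stop=[], stream=True)
while True:
    try:
        event = next(generator)
    except StopIteration as done:
        result = done.value  # AgentResult: text, files, usage, finish_reason
        break
    if isinstance(event, AgentLog):
        print(event.label, event.status)  # structured round / thought / tool-call logs
    else:
        ...  # LLMResultChunk: forward the streamed delta to the client
```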
|
||||
|
||||
## Integration Points
|
||||
|
||||
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
|
||||
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers (see the stub sketch below).
|
||||
- **Files**: file handling flows through `File` objects for tool inputs/outputs and model-context attachments.
|
||||
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
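
A skeletal example of the `tool_invoke_hook` extension point. The signature mirrors the `ToolInvokeHook` alias in `patterns/base.py`; the body is left to the host application and raises by default.

```python
from typing import Any

from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta


def custom_tool_invoke_hook(
    tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
    """Illustrative stub: run the tool through a custom engine.

    Must return the text shown to the model, the IDs of any message files
    created along the way, and the ToolInvokeMeta describing the call.
    """
    raise NotImplementedError  # left to the host application


# Passed to the factory as:
#   StrategyFactory.create_strategy(..., tool_invoke_hook=custom_tool_invoke_hook)
```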
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
|
|
@ -0,0 +1,444 @@
|
|||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files or []
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: finished_at - started_at,
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine().generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
|
|
@ -0,0 +1,295 @@
|
|||
"""Function Call strategy implementation."""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _ = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if isinstance(chunks, Generator):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
|
||||
"""Handle a single tool call and return response with files and meta."""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None
|
||||
|
|
@ -0,0 +1,415 @@
|
|||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.file import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=self.context.user_id or "",
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
# Create a generator from the LLMResult
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=chunks.model,
|
||||
prompt_messages=chunks.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=chunks.message,
|
||||
usage=chunks.usage,
|
||||
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
|
||||
),
|
||||
system_fingerprint=chunks.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Start tool log
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
# Find tool instance
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.file.models import File
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentPattern instance
|
||||
"""
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
|
@ -4,6 +4,7 @@ import re
|
|||
import time
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Union
|
||||
|
||||
|
|
@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
ChunkType,
|
||||
MessageQueueMessage,
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
|
|
@@ -70,13 +72,120 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow

logger = logging.getLogger(__name__)


@dataclass
class StreamEventBuffer:
"""
Buffer for recording stream events in order to reconstruct the generation sequence.
Records the exact order of text chunks, thoughts, and tool calls as they stream.
"""

# Accumulated reasoning content (each thought block is a separate element)
reasoning_content: list[str] = field(default_factory=list)
# Current reasoning buffer (accumulates until we see a different event type)
_current_reasoning: str = ""
# Tool calls with their details
tool_calls: list[dict] = field(default_factory=list)
# Tool call ID to index mapping for updating results
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
# Sequence of events in stream order
sequence: list[dict] = field(default_factory=list)
# Current position in answer text
_content_position: int = 0
# Track last event type to detect transitions
_last_event_type: str | None = None

def _flush_current_reasoning(self) -> None:
"""Flush accumulated reasoning to the list and add to sequence."""
if self._current_reasoning.strip():
self.reasoning_content.append(self._current_reasoning.strip())
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
self._current_reasoning = ""

def record_text_chunk(self, text: str) -> None:
"""Record a text chunk event."""
if not text:
return

# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()

text_len = len(text)
start_pos = self._content_position

# If last event was also content, extend it; otherwise create new
if self.sequence and self.sequence[-1].get("type") == "content":
self.sequence[-1]["end"] = start_pos + text_len
else:
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})

self._content_position += text_len
self._last_event_type = "content"

def record_thought_chunk(self, text: str) -> None:
"""Record a thought/reasoning chunk event."""
if not text:
return

# Accumulate thought content
self._current_reasoning += text
self._last_event_type = "thought"

def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
"""Record a tool call event."""
if not tool_call_id:
return

# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()

# Check if this tool call already exists (we might get multiple chunks)
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
# Update arguments if provided
if tool_arguments:
self.tool_calls[idx]["arguments"] = tool_arguments
else:
# New tool call
tool_call = {
"id": tool_call_id or "",
"name": tool_name or "",
"arguments": tool_arguments or "",
"result": "",
}
self.tool_calls.append(tool_call)
idx = len(self.tool_calls) - 1
self._tool_call_id_map[tool_call_id] = idx
self.sequence.append({"type": "tool_call", "index": idx})

self._last_event_type = "tool_call"

def record_tool_result(self, tool_call_id: str, result: str) -> None:
"""Record a tool result event (update existing tool call)."""
if not tool_call_id:
return
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
self.tool_calls[idx]["result"] = result

def finalize(self) -> None:
"""Finalize the buffer, flushing any pending data."""
if self._last_event_type == "thought":
self._flush_current_reasoning()

def has_data(self) -> bool:
"""Check if there's any meaningful data recorded."""
return bool(self.reasoning_content or self.tool_calls or self.sequence)
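
# Illustrative usage sketch of StreamEventBuffer; the event values below are hypothetical
# and only show how a thought/tool-call/text stream is reduced to a recorded sequence:
def _demo_stream_event_buffer() -> StreamEventBuffer:
    buffer = StreamEventBuffer()
    buffer.record_thought_chunk("I should call the weather tool.")
    buffer.record_tool_call(tool_call_id="call_1", tool_name="get_weather", tool_arguments='{"city": "Paris"}')
    buffer.record_tool_result(tool_call_id="call_1", result="22C, sunny")
    buffer.record_text_chunk("It is ")
    buffer.record_text_chunk("22C in Paris.")
    buffer.finalize()
    # buffer.sequence is now:
    #   [{"type": "reasoning", "index": 0},
    #    {"type": "tool_call", "index": 0},
    #    {"type": "content", "start": 0, "end": 19}]
    return buffer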


class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.

@@ -144,6 +253,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
# Stream event buffer for recording generation sequence
self._stream_buffer = StreamEventBuffer()
self._seed_graph_runtime_state_from_queue_manager()

def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:

@@ -383,7 +494,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
delta_text = event.text
if delta_text is None:
return
@@ -405,9 +516,45 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
tool_files = tool_result.files if tool_result else []

# Record stream event based on chunk type
chunk_type = event.chunk_type or ChunkType.TEXT
match chunk_type:
case ChunkType.TEXT:
self._stream_buffer.record_text_chunk(delta_text)
self._task_state.answer += delta_text
case ChunkType.THOUGHT:
# Reasoning should not be part of final answer text
self._stream_buffer.record_thought_chunk(delta_text)
case ChunkType.TOOL_CALL:
self._stream_buffer.record_tool_call(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
)
case ChunkType.TOOL_RESULT:
self._stream_buffer.record_tool_result(
tool_call_id=tool_call_id,
result=delta_text,
)
self._task_state.answer += delta_text

yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
answer=delta_text,
message_id=self._message_id,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type.value if event.chunk_type else None,
tool_call_id=tool_call_id or None,
tool_name=tool_name or None,
tool_arguments=tool_arguments or None,
tool_files=tool_files,
)

def _handle_iteration_start_event(
@@ -775,6 +922,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):

# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
answer_text = self._strip_think_blocks(answer_text)
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
@@ -826,6 +974,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)

# Save generation detail (reasoning/tool calls/sequence) from stream buffer
self._save_generation_detail(session=session, message=message)

@staticmethod
def _strip_think_blocks(text: str) -> str:
"""Remove <think>...</think> blocks (including their content) from text."""
if not text or "<think" not in text.lower():
return text

clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text

def _save_generation_detail(self, *, session: Session, message: Message) -> None:
"""
Save LLM generation detail for Chatflow using stream event buffer.
The buffer records the exact order of events as they streamed,
allowing accurate reconstruction of the generation sequence.
"""
# Finalize the stream buffer to flush any pending data
self._stream_buffer.finalize()

# Only save if there's meaningful data
if not self._stream_buffer.has_data():
return

reasoning_content = self._stream_buffer.reasoning_content
tool_calls = self._stream_buffer.tool_calls
sequence = self._stream_buffer.sequence

# Check if generation detail already exists for this message
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()

if existing:
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
tool_calls=json.dumps(tool_calls) if tool_calls else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)

def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
@@ -3,10 +3,8 @@ from typing import cast

from sqlalchemy import select

from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner

@@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
@@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner):
raise ValueError("Message not found")
db.session.close()

runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")

runner = runner_cls(
runner = AgentAppRunner(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,
@@ -671,7 +671,7 @@ class WorkflowResponseConverter:
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
id=event.id,
message_id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,

@@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
ChunkType,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@@ -483,11 +484,27 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if delta_text is None:
return

tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
tool_arguments = tool_call.arguments if tool_call else None
tool_files = tool_result.files if tool_result else []

# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)

yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
yield self._text_chunk_to_stream_response(
text=delta_text,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
)

def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
@@ -650,16 +667,35 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
session.add(workflow_app_log)

def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: list[str] | None = None
self,
text: str,
from_variable_selector: list[str] | None = None,
chunk_type: ChunkType | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
from core.app.entities.task_entities import ChunkType as ResponseChunkType

response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
data=TextChunkStreamResponse.Data(
text=text,
from_variable_selector=from_variable_selector,
chunk_type=ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files or [],
tool_error=tool_error,
),
)

return response
@@ -455,12 +455,20 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
from core.app.entities.queue_entities import ChunkType as QueueChunkType

if event.is_final and not event.chunk:
return

self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
chunk_type=QueueChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -0,0 +1,69 @@
"""
LLM Generation Detail entities.

Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""

from typing import Literal

from pydantic import BaseModel, Field


class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""

type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")


class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""

type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")


class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""

type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")


SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment


class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""

id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")


class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.

Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""

reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")

def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls

def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}
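
# Illustrative sketch of how these entities compose; all values are hypothetical:
def _demo_generation_detail() -> dict:
    detail = LLMGenerationDetailData(
        reasoning_content=["The user asks for the weather, so a tool call is needed."],
        tool_calls=[ToolCallDetail(id="call_1", name="get_weather", arguments='{"city": "Paris"}', result="22C")],
        sequence=[
            ReasoningSegment(index=0),
            ToolCallSegment(index=0),
            ContentSegment(start=0, end=19),
        ],
    )
    # to_response_dict() serializes the nested models; presumably this is the shape
    # surfaced through the generation_detail field of the message APIs.
    return detail.to_response_dict()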

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -177,6 +177,15 @@ class QueueLoopCompletedEvent(AppQueueEvent):
error: str | None = None


class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""

TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)


class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity

@@ -191,6 +200,16 @@ class QueueTextChunkEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""

# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""

# Tool streaming payloads
tool_call: ToolCall | None = None
"""structured tool call info"""
tool_result: ToolResult | None = None
"""structured tool result info"""


class QueueAgentMessageEvent(AppQueueEvent):
"""
@@ -113,6 +113,24 @@ class MessageStreamResponse(StreamResponse):
answer: str
from_variable_selector: list[str] | None = None

# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
chunk_type: str | None = None
"""type of the chunk: text, tool_call, tool_result, thought"""

# Tool call fields (when chunk_type == "tool_call")
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""

# Tool result fields (when chunk_type == "tool_result")
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""


class MessageAudioStreamResponse(StreamResponse):
"""
@@ -582,6 +600,15 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
data: Data


class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""

TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)


class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity

@@ -595,6 +622,24 @@ class TextChunkStreamResponse(StreamResponse):
text: str
from_variable_selector: list[str] | None = None

# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""

# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""

# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] = Field(default_factory=list)
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""

event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@@ -743,7 +788,7 @@ class AgentLogStreamResponse(StreamResponse):
"""

node_execution_id: str
id: str
message_id: str
label: str
parent_id: str | None = None
error: str | None = None
@@ -1,4 +1,5 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread

@@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought

logger = logging.getLogger(__name__)
@@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
@@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
)

# Save LLM generation detail if there's reasoning_content
self._save_generation_detail(session=session, message=message, llm_result=llm_result)

message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)

def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
"""
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
For Agent-Chat, also merges MessageAgentThought records.
"""
import json

reasoning_list: list[str] = []
tool_calls_list: list[dict] = []
sequence: list[dict] = []
answer = message.answer or ""

# Check if this is Agent-Chat mode by looking for agent thoughts
agent_thoughts = (
session.query(MessageAgentThought)
.filter_by(message_id=message.id)
.order_by(MessageAgentThought.position.asc())
.all()
)

if agent_thoughts:
# Agent-Chat mode: merge MessageAgentThought records
content_pos = 0
cleaned_answer_parts: list[str] = []
for thought in agent_thoughts:
# Add thought/reasoning
if thought.thought:
reasoning_text = thought.thought
if "<think" in reasoning_text.lower():
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
if extracted_reasoning:
reasoning_text = extracted_reasoning
thought.thought = clean_text or extracted_reasoning
reasoning_list.append(reasoning_text)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})

# Add tool calls
if thought.tool:
tool_calls_list.append(
{
"name": thought.tool,
"arguments": thought.tool_input or "",
"result": thought.observation or "",
}
)
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})

# Add answer content if present
if thought.answer:
content_text = thought.answer
if "<think" in content_text.lower():
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
if extracted_reasoning:
reasoning_list.append(extracted_reasoning)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
content_text = clean_answer
thought.answer = clean_answer or content_text

if content_text:
start = content_pos
end = content_pos + len(content_text)
sequence.append({"type": "content", "start": start, "end": end})
content_pos = end
cleaned_answer_parts.append(content_text)

if cleaned_answer_parts:
merged_answer = "".join(cleaned_answer_parts)
message.answer = merged_answer
llm_result.message.content = merged_answer
else:
# Completion/Chat mode: use reasoning_content from llm_result
reasoning_content = llm_result.reasoning_content
if not reasoning_content and answer:
# Extract reasoning from <think> blocks and clean the final answer
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
if reasoning_content:
answer = clean_answer
llm_result.message.content = clean_answer
llm_result.reasoning_content = reasoning_content
message.answer = clean_answer
if reasoning_content:
reasoning_list = [reasoning_content]
# Content comes first, then reasoning
if answer:
sequence.append({"type": "content", "start": 0, "end": len(answer)})
sequence.append({"type": "reasoning", "index": 0})

# Only save if there's meaningful generation detail
if not reasoning_list and not tool_calls_list:
return

# Check if generation detail already exists
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()

if existing:
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)

@classmethod
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
"""
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
"""
matches = cls._THINK_PATTERN.findall(text)
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""

clean_text = cls._THINK_PATTERN.sub("", text)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()

return clean_text, reasoning_content or ""
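
# Example of the split on a hypothetical answer:
#   _split_reasoning_from_answer("<think>Check the docs first.</think>The answer is 42.")
#   returns ("The answer is 42.", "Check the docs first.")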

def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.
@@ -232,12 +232,25 @@ class MessageCycleManager:
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
chunk_type: str | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:param from_variable_selector: from variable selector
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
:param tool_call_id: unique identifier for this tool call
:param tool_name: name of the tool being called
:param tool_arguments: accumulated tool arguments JSON
:param tool_files: file IDs produced by tool
:param tool_error: error message if tool failed
:return:
"""
return MessageStreamResponse(

@@ -245,6 +258,12 @@ class MessageCycleManager:
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
chunk_type=chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
tool_error=tool_error,
event=event_type or StreamEvent.MESSAGE,
)
@@ -5,7 +5,6 @@ from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document

@@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler:
# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)
@@ -29,6 +29,7 @@ from models import (
Account,
CreatorUserRole,
EndUser,
LLMGenerationDetail,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)
@@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
session.merge(db_model)
session.flush()

# Save LLMGenerationDetail for LLM nodes with successful execution
if (
domain_model.node_type == NodeType.LLM
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
and domain_model.outputs is not None
):
self._save_llm_generation_detail(session, domain_model)

def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
"""
Save LLM generation detail for LLM nodes.
Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
"""
outputs = execution.outputs or {}
metadata = execution.metadata or {}

reasoning_list = self._extract_reasoning(outputs)
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))

if not reasoning_list and not tool_calls_list:
return

sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)

def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
"""Extract reasoning_content as a clean list of non-empty strings."""
reasoning_content = outputs.get("reasoning_content")
if isinstance(reasoning_content, str):
trimmed = reasoning_content.strip()
return [trimmed] if trimmed else []
if isinstance(reasoning_content, list):
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
return []

def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
"""Extract tool call records from agent logs."""
if not agent_log or not isinstance(agent_log, list):
return []

tool_calls: list[dict[str, str]] = []
for log in agent_log:
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
tool_name = log_data.get("tool_name")
if tool_name and str(tool_name).strip():
tool_calls.append(
{
"id": log_data.get("tool_call_id", ""),
"name": tool_name,
"arguments": json.dumps(log_data.get("tool_args", {})),
"result": str(log_data.get("output", "")),
}
)
return tool_calls

def _build_generation_sequence(
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
) -> list[dict[str, Any]]:
"""Build a simple content/reasoning/tool_call sequence."""
sequence: list[dict[str, Any]] = []
if text:
sequence.append({"type": "content", "start": 0, "end": len(text)})
for index in range(len(reasoning_list)):
sequence.append({"type": "reasoning", "index": index})
for index in range(len(tool_calls_list)):
sequence.append({"type": "tool_call", "index": index})
return sequence
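
# For example (hypothetical inputs), text="Hello" with one reasoning entry and one
# tool call yields:
#   [{"type": "content", "start": 0, "end": 5},
#    {"type": "reasoning", "index": 0},
#    {"type": "tool_call", "index": 0}]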

def _upsert_generation_detail(
self,
session,
execution: WorkflowNodeExecution,
reasoning_list: list[str],
tool_calls_list: list[dict[str, str]],
sequence: list[dict[str, Any]],
) -> None:
"""Insert or update LLMGenerationDetail with serialized fields."""
existing = (
session.query(LLMGenerationDetail)
.filter_by(
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
)
.first()
)

reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
sequence_json = json.dumps(sequence) if sequence else None

if existing:
existing.reasoning_content = reasoning_json
existing.tool_calls = tool_calls_json
existing.sequence = sequence_json
return

generation_detail = LLMGenerationDetail(
tenant_id=self._tenant_id,
app_id=self._app_id,
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
reasoning_content=reasoning_json,
tool_calls=tool_calls_json,
sequence=sequence_json,
)
session.add(generation_detail)

def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import File

from core.model_runtime.entities.message_entities import PromptMessageTool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
@@ -152,6 +153,60 @@ class Tool(ABC):

return parameters

def to_prompt_message_tool(self) -> PromptMessageTool:
message_tool = PromptMessageTool(
name=self.entity.identity.name,
description=self.entity.description.llm if self.entity.description else "",
parameters={
"type": "object",
"properties": {},
"required": [],
},
)

parameters = self.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue

parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
# Determine the description based on parameter type
if parameter.type == ToolParameter.ToolParameterType.FILE:
file_format_desc = " Input the file id with format: [File: file_id]."
else:
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "

message_tool.parameters["properties"][parameter.name] = {
"type": "string",
"description": (parameter.llm_description or "") + file_format_desc,
}
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []

message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)

if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum

if parameter.required:
message_tool.parameters["required"].append(parameter.name)

return message_tool
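
# Sketch of the schema produced for a hypothetical tool with one required string
# parameter "query" and one select parameter "unit" (SELECT normally maps to a
# string type with an enum):
#   PromptMessageTool(
#       name="get_weather",
#       description="Query current weather",
#       parameters={
#           "type": "object",
#           "properties": {
#               "query": {"type": "string", "description": "City to look up"},
#               "unit": {"type": "string", "description": "Temperature unit", "enum": ["c", "f"]},
#           },
#           "required": ["query"],
#       },
#   )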

def create_image_message(
self,
image: str,
@@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution

__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]
@@ -0,0 +1,33 @@
from enum import StrEnum

from pydantic import BaseModel, Field

from core.file import File


class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"


class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")


class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")


class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")

@@ -247,6 +247,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field

from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.

For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.

Args:
node_id: The node ID to attribute the event to
execution_id: The execution ID for this node
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:

@@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)

# Standard case: selector refers to an actual node

@@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)

def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@@ -356,6 +381,8 @@ class ResponseStreamCoordinator:

Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.

For object-type variables, automatically streams all child fields that have stream events.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""

@@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
active_session = self._active_session
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)

# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if there's a direct stream for this selector
has_direct_stream = (
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
)

# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))

elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
if stream_targets:
all_complete = True

for target_selector in stream_targets:
while self._has_unread_stream(target_selector):
if event := self._pop_stream_chunk(target_selector):
events.append(
self._rewrite_stream_event(
event=event,
output_node_id=output_node_id,
execution_id=execution_id,
special_selector=bool(special_selector),
)
)

if not self._is_stream_closed(target_selector):
all_complete = False

is_complete = all_complete

# Fallback: check if scalar value exists in variable pool
if not is_complete and not has_direct_stream:
if value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session
and self._active_session.index == len(self._active_session.template.segments) - 1
)
)
is_complete = True
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True

return events, is_complete

def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event

return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)

def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
@@ -513,6 +561,36 @@ class ResponseStreamCoordinator:

# ============= Internal Stream Management Methods =============

def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
"""Find all child stream selectors that are descendants of the parent selector.

For example, if parent_selector is ['llm', 'generation'], this will find:
- ['llm', 'generation', 'content']
- ['llm', 'generation', 'tool_calls']
- ['llm', 'generation', 'tool_results']
- ['llm', 'generation', 'thought']

Args:
parent_selector: The parent selector to search for children

Returns:
List of child selector tuples found in stream buffers or closed streams
"""
parent_key = tuple(parent_selector)
parent_len = len(parent_key)
child_streams: set[tuple[str, ...]] = set()

# Search in both active buffers and closed streams
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams

for selector_key in all_selectors:
# Check if this selector is a direct child of the parent
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
child_streams.add(selector_key)

return sorted(child_streams)
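
# Example (hypothetical buffer state): with streams registered for
# ("llm", "generation", "content") and ("llm", "generation", "thought"),
# _find_child_streams(["llm", "generation"]) returns both tuples, while a
# grandchild such as ("llm", "generation", "tool_calls", "0") is skipped
# because only direct children (parent length + 1) match.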

def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.
@@ -36,6 +36,7 @@ from .loop import (

# Node events
from .node import (
ChunkType,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,

@@ -44,10 +45,13 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)

__all__ = [
"BaseGraphEvent",
"ChunkType",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",

@@ -73,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]
@@ -1,10 +1,11 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum

from pydantic import Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason

from .base import GraphNodeEventBase

@@ -21,13 +22,37 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_id: str = ""


class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""

TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)


class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
"""Stream chunk event for workflow node execution."""

# Base fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")

# Tool call fields (when chunk_type == TOOL_CALL)
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)

# Tool result fields (when chunk_type == TOOL_RESULT)
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)


class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
@@ -13,16 +13,21 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
ChunkType,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)

__all__ = [
"AgentLogEvent",
"ChunkType",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",

@@ -39,4 +44,7 @@ __all__ = [
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
"ThoughtChunkEvent",
"ToolCallChunkEvent",
"ToolResultChunkEvent",
]
@@ -1,11 +1,13 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum

from pydantic import Field

from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult

@@ -32,13 +34,46 @@ class RunRetryEvent(NodeEventBase):
start_at: datetime = Field(..., description="Retry start time")


class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""

TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)


class StreamChunkEvent(NodeEventBase):
# Spec-compliant fields
"""Base stream chunk event - normal text streaming output."""

selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")


class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""

chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")


class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""

chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")


class ThoughtChunkEvent(StreamChunkEvent):
"""Agent thought streaming event - Agent thinking process (ReAct)."""

chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
|
||||
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
|
|
|
|||
|
|
@@ -46,6 +46,9 @@ from core.workflow.node_events import (
    RunRetrieverResourceEvent,
    StreamChunkEvent,
    StreamCompletedEvent,
    ThoughtChunkEvent,
    ToolCallChunkEvent,
    ToolResultChunkEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now

@@ -543,6 +546,8 @@ class Node(Generic[NodeDataT]):

    @_dispatch.register
    def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
        from core.workflow.graph_events import ChunkType

        return NodeRunStreamChunkEvent(
            id=self.execution_id,
            node_id=self._node_id,
@@ -550,6 +555,65 @@ class Node(Generic[NodeDataT]):
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=ChunkType(event.chunk_type.value),
            tool_call=event.tool_call,
            tool_result=event.tool_result,
        )

    @_dispatch.register
    def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
        from core.workflow.graph_events import ChunkType

        return NodeRunStreamChunkEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=ChunkType.TOOL_CALL,
            tool_call=event.tool_call,
        )

    @_dispatch.register
    def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
        from core.workflow.entities import ToolResult, ToolResultStatus
        from core.workflow.graph_events import ChunkType

        tool_result = event.tool_result
        status: ToolResultStatus = (
            tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS
        )

        return NodeRunStreamChunkEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=ChunkType.TOOL_RESULT,
            tool_result=ToolResult(
                id=tool_result.id if tool_result else None,
                name=tool_result.name if tool_result else None,
                output=tool_result.output if tool_result else None,
                files=tool_result.files if tool_result else [],
                status=status,
            ),
        )

    @_dispatch.register
    def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
        from core.workflow.graph_events import ChunkType

        return NodeRunStreamChunkEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=ChunkType.THOUGHT,
        )

    @_dispatch.register

@@ -3,6 +3,7 @@ from .entities import (
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
    ToolMetadata,
    VisionConfig,
)
from .node import LLMNode

@@ -13,5 +14,6 @@ __all__ = [
    "LLMNodeCompletionModelPromptTemplate",
    "LLMNodeData",
    "ModelConfig",
    "ToolMetadata",
    "VisionConfig",
]

@@ -1,10 +1,17 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector

@@ -58,6 +65,235 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    jinja2_text: str | None = None


class ToolMetadata(BaseModel):
    """
    Tool metadata for LLM node with tool support.

    Defines the essential fields needed for tool configuration,
    particularly the 'type' field to identify tool provider type.
    """

    # Core fields
    enabled: bool = True
    type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
    provider_name: str = Field(..., description="Tool provider name/identifier")
    tool_name: str = Field(..., description="Tool name")

    # Optional fields
    plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
    credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")

    # Configuration fields
    parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
    settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
    extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")

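# --- Illustrative sketch (not part of the commit) -------------------------------
# A plausible tool entry for an LLM node, using only the fields declared above.
# The provider/tool names, parameter values, and the ToolProviderType member name
# are assumptions, not taken from the commit.
#
# search_tool = ToolMetadata(
#     enabled=True,
#     type=ToolProviderType.BUILT_IN,            # assumed member name for "builtin"
#     provider_name="duckduckgo",
#     tool_name="ddg_search",
#     parameters={"top_k": 5},
#     settings={},
#     extra={"description": "Search the web and return short snippets."},
# )
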
class LLMTraceSegment(BaseModel):
    """
    Streaming trace segment for LLM tool-enabled runs.

    Order is preserved for replay. Tool calls are single entries containing both
    arguments and results.
    """

    type: Literal["thought", "content", "tool_call"]

    # Common optional fields
    text: str | None = Field(None, description="Text chunk for thought/content")

    # Tool call fields (combined start + result)
    tool_call: ToolCallResult | None = Field(
        default=None,
        description="Combined tool call arguments and result for this segment",
    )


class LLMGenerationData(BaseModel):
    """Generation data from LLM invocation with tools.

    For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
    - reasoning_contents: [thought1, thought2, ...] - one element per turn
    - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
    """

    text: str = Field(..., description="Accumulated text content from all turns")
    reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
    tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
    sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
    usage: LLMUsage = Field(..., description="LLM usage statistics")
    finish_reason: str | None = Field(None, description="Finish reason from LLM")
    files: list[File] = Field(default_factory=list, description="Generated files")
    trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")

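# --- Illustrative sketch (not part of the commit) -------------------------------
# Rough shape of a run that thinks, calls one tool, thinks again, then answers,
# using the sequence entry formats built by _build_generation_data further below.
# All values (and the ToolCallResult/ToolResultStatus imports) are assumed.
#
# generation = LLMGenerationData(
#     text="It is 18 degrees and sunny in Paris.",
#     reasoning_contents=["I should look up the weather first.", "Now I can answer."],
#     tool_calls=[
#         ToolCallResult(id="call_1", name="get_weather", arguments='{"city": "Paris"}',
#                        output="18 degrees, sunny", status=ToolResultStatus.SUCCESS),
#     ],
#     sequence=[
#         {"type": "reasoning", "index": 0},
#         {"type": "tool_call", "index": 0},
#         {"type": "reasoning", "index": 1},
#         {"type": "content", "start": 0, "end": 37},  # end == len(text)
#     ],
#     usage=LLMUsage.empty_usage(),
#     finish_reason="stop",
# )
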
class ThinkTagStreamParser:
    """Lightweight state machine to split streaming chunks by <think> tags."""

    _START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
    _END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
    _START_PREFIX = "<think"
    _END_PREFIX = "</think"

    def __init__(self):
        self._buffer = ""
        self._in_think = False

    @staticmethod
    def _suffix_prefix_len(text: str, prefix: str) -> int:
        """Return length of the longest suffix of `text` that is a prefix of `prefix`."""
        max_len = min(len(text), len(prefix) - 1)
        for i in range(max_len, 0, -1):
            if text[-i:].lower() == prefix[:i].lower():
                return i
        return 0

    def process(self, chunk: str) -> list[tuple[str, str]]:
        """
        Split incoming chunk into ('thought' | 'text', content) tuples.
        Content excludes the <think> tags themselves and handles split tags across chunks.
        """
        parts: list[tuple[str, str]] = []
        self._buffer += chunk

        while self._buffer:
            if self._in_think:
                end_match = self._END_PATTERN.search(self._buffer)
                if end_match:
                    thought_text = self._buffer[: end_match.start()]
                    if thought_text:
                        parts.append(("thought", thought_text))
                    self._buffer = self._buffer[end_match.end() :]
                    self._in_think = False
                    continue

                hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
                emit = self._buffer[: len(self._buffer) - hold_len]
                if emit:
                    parts.append(("thought", emit))
                self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
                break

            start_match = self._START_PATTERN.search(self._buffer)
            if start_match:
                prefix = self._buffer[: start_match.start()]
                if prefix:
                    parts.append(("text", prefix))
                self._buffer = self._buffer[start_match.end() :]
                self._in_think = True
                continue

            hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
            emit = self._buffer[: len(self._buffer) - hold_len]
            if emit:
                parts.append(("text", emit))
            self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
            break

        cleaned_parts: list[tuple[str, str]] = []
        for kind, content in parts:
            # Extra safeguard: strip any stray tags that slipped through.
            content = self._START_PATTERN.sub("", content)
            content = self._END_PATTERN.sub("", content)
            if content:
                cleaned_parts.append((kind, content))

        return cleaned_parts

    def flush(self) -> list[tuple[str, str]]:
        """Flush remaining buffer when the stream ends."""
        if not self._buffer:
            return []
        kind = "thought" if self._in_think else "text"
        content = self._buffer
        # Drop dangling partial tags instead of emitting them
        if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
            content = ""
        self._buffer = ""
        if not content:
            return []
        # Strip any complete tags that might still be present.
        content = self._START_PATTERN.sub("", content)
        content = self._END_PATTERN.sub("", content)
        return [(kind, content)] if content else []

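# --- Illustrative sketch (not part of the commit) -------------------------------
# How the parser behaves on chunks that split a <think> tag across boundaries.
# Runnable as-is against the class above; the input strings are made up.
#
# parser = ThinkTagStreamParser()
# parser.process("Hello <thi")   # -> [("text", "Hello ")]   ("<thi" is held back in the buffer)
# parser.process("nk>planning</think>world")
#                                # -> [("thought", "planning"), ("text", "world")]
# parser.flush()                 # -> []  (buffer already drained)
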
class StreamBuffers(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
    pending_thought: list[str] = Field(default_factory=list)
    pending_content: list[str] = Field(default_factory=list)
    current_turn_reasoning: list[str] = Field(default_factory=list)
    reasoning_per_turn: list[str] = Field(default_factory=list)


class TraceState(BaseModel):
    trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
    tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
    tool_call_index_map: dict[str, int] = Field(default_factory=dict)


class AggregatedResult(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    text: str = ""
    files: list[File] = Field(default_factory=list)
    usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
    finish_reason: str | None = None


class AgentContext(BaseModel):
    agent_logs: list[AgentLogEvent] = Field(default_factory=list)
    agent_result: AgentResult | None = None


class ToolOutputState(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    stream: StreamBuffers = Field(default_factory=StreamBuffers)
    trace: TraceState = Field(default_factory=TraceState)
    aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
    agent: AgentContext = Field(default_factory=AgentContext)


class ToolLogPayload(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    tool_name: str = ""
    tool_call_id: str = ""
    tool_args: dict[str, Any] = Field(default_factory=dict)
    tool_output: Any = None
    tool_error: Any = None
    files: list[Any] = Field(default_factory=list)
    meta: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_log(cls, log: AgentLog) -> "ToolLogPayload":
        data = log.data or {}
        return cls(
            tool_name=data.get("tool_name", ""),
            tool_call_id=data.get("tool_call_id", ""),
            tool_args=data.get("tool_args") or {},
            tool_output=data.get("output"),
            tool_error=data.get("error"),
            files=data.get("files") or [],
            meta=data.get("meta") or {},
        )

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
        return cls(
            tool_name=data.get("tool_name", ""),
            tool_call_id=data.get("tool_call_id", ""),
            tool_args=data.get("tool_args") or {},
            tool_output=data.get("output"),
            tool_error=data.get("error"),
            files=data.get("files") or [],
            meta=data.get("meta") or {},
        )


class LLMNodeData(BaseNodeData):
    model: ModelConfig
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate

@@ -86,6 +322,10 @@ class LLMNodeData(BaseNodeData):
        ),
    )

    # Tool support
    tools: Sequence[ToolMetadata] = Field(default_factory=list)
    max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")

    @field_validator("prompt_config", mode="before")
    @classmethod
    def convert_none_prompt_config(cls, v: Any):

@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.patterns import StrategyFactory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage

@@ -46,7 +48,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.__base.tool import Tool
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
from core.variables import (
    ArrayFileSegment,
    ArraySegment,

@@ -56,7 +60,8 @@ from core.variables import (
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
from core.workflow.enums import (
    NodeType,
    SystemVariableKey,

@@ -64,12 +69,16 @@ from core.workflow.enums import (
    WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import (
    AgentLogEvent,
    ModelInvokeCompletedEvent,
    NodeEventBase,
    NodeRunResult,
    RunRetrieverResourceEvent,
    StreamChunkEvent,
    StreamCompletedEvent,
    ThoughtChunkEvent,
    ToolCallChunkEvent,
    ToolResultChunkEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node

@@ -81,10 +90,19 @@ from models.model import UploadFile

from . import llm_utils
from .entities import (
    AgentContext,
    AggregatedResult,
    LLMGenerationData,
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    LLMTraceSegment,
    ModelConfig,
    StreamBuffers,
    ThinkTagStreamParser,
    ToolLogPayload,
    ToolOutputState,
    TraceState,
)
from .exc import (
    InvalidContextStructureError,

@@ -149,11 +167,11 @@ class LLMNode(Node[LLMNodeData]):
    def _run(self) -> Generator:
        node_inputs: dict[str, Any] = {}
        process_data: dict[str, Any] = {}
        result_text = ""
        clean_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
        reasoning_content = None
        reasoning_content = ""  # Initialize as empty string for consistency
        clean_text = ""  # Initialize clean_text to avoid UnboundLocalError
        variable_pool = self.graph_runtime_state.variable_pool

        try:

@@ -234,55 +252,58 @@ class LLMNode(Node[LLMNodeData]):
                context_files=context_files,
            )

            # handle invoke result
            generator = LLMNode.invoke_llm(
                node_data_model=self.node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=self.node_data.structured_output_enabled,
                structured_output=self.node_data.structured_output,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
                reasoning_format=self.node_data.reasoning_format,
            )

            # Variables for outputs
            generation_data: LLMGenerationData | None = None
            structured_output: LLMStructuredOutput | None = None

            for event in generator:
                if isinstance(event, StreamChunkEvent):
                    yield event
                elif isinstance(event, ModelInvokeCompletedEvent):
                    # Raw text
                    result_text = event.text
                    usage = event.usage
                    finish_reason = event.finish_reason
                    reasoning_content = event.reasoning_content or ""
            # Check if tools are configured
            if self.tool_call_enabled:
                # Use tool-enabled invocation (Agent V2 style)
                generator = self._invoke_llm_with_tools(
                    model_instance=model_instance,
                    prompt_messages=prompt_messages,
                    stop=stop,
                    files=files,
                    variable_pool=variable_pool,
                    node_inputs=node_inputs,
                    process_data=process_data,
                )
            else:
                # Use traditional LLM invocation
                generator = LLMNode.invoke_llm(
                    node_data_model=self._node_data.model,
                    model_instance=model_instance,
                    prompt_messages=prompt_messages,
                    stop=stop,
                    user_id=self.user_id,
                    structured_output_enabled=self._node_data.structured_output_enabled,
                    structured_output=self._node_data.structured_output,
                    file_saver=self._llm_file_saver,
                    file_outputs=self._file_outputs,
                    node_id=self._node_id,
                    node_type=self.node_type,
                    reasoning_format=self._node_data.reasoning_format,
                )

                    # For downstream nodes, determine clean text based on reasoning_format
                    if self.node_data.reasoning_format == "tagged":
                        # Keep <think> tags for backward compatibility
                        clean_text = result_text
                    else:
                        # Extract clean text from <think> tags
                        clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
            (
                clean_text,
                reasoning_content,
                generation_reasoning_content,
                generation_clean_content,
                usage,
                finish_reason,
                structured_output,
                generation_data,
            ) = yield from self._stream_llm_events(generator, model_instance=model_instance)

                    # Process structured output if available from the event.
                    structured_output = (
                        LLMStructuredOutput(structured_output=event.structured_output)
                        if event.structured_output
                        else None
                    )

                    # deduct quota
                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    break
                elif isinstance(event, LLMStructuredOutput):
                    structured_output = event
            # Extract variables from generation_data if available
            if generation_data:
                clean_text = generation_data.text
                reasoning_content = ""
                usage = generation_data.usage
                finish_reason = generation_data.finish_reason

            # Unified process_data building
            process_data = {
                "model_mode": model_config.mode,
                "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(

@@ -293,24 +314,88 @@ class LLMNode(Node[LLMNodeData]):
                "model_provider": model_config.provider,
                "model_name": model_config.model,
            }
            if self.tool_call_enabled and self._node_data.tools:
                process_data["tools"] = [
                    {
                        "type": tool.type.value if hasattr(tool.type, "value") else tool.type,
                        "provider_name": tool.provider_name,
                        "tool_name": tool.tool_name,
                    }
                    for tool in self._node_data.tools
                    if tool.enabled
                ]

            # Unified outputs building
            outputs = {
                "text": clean_text,
                "reasoning_content": reasoning_content,
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
            }

            # Build generation field
            if generation_data:
                # Use generation_data from tool invocation (supports multi-turn)
                generation = {
                    "content": generation_data.text,
                    "reasoning_content": generation_data.reasoning_contents,  # [thought1, thought2, ...]
                    "tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls],
                    "sequence": generation_data.sequence,
                }
                files_to_output = generation_data.files
            else:
                # Traditional LLM invocation
                generation_reasoning = generation_reasoning_content or reasoning_content
                generation_content = generation_clean_content or clean_text
                sequence: list[dict[str, Any]] = []
                if generation_reasoning:
                    sequence = [
                        {"type": "reasoning", "index": 0},
                        {"type": "content", "start": 0, "end": len(generation_content)},
                    ]
                generation = {
                    "content": generation_content,
                    "reasoning_content": [generation_reasoning] if generation_reasoning else [],
                    "tool_calls": [],
                    "sequence": sequence,
                }
                files_to_output = self._file_outputs

            outputs["generation"] = generation
            if files_to_output:
                outputs["files"] = ArrayFileSegment(value=files_to_output)
            if structured_output:
                outputs["structured_output"] = structured_output.structured_output
            if self._file_outputs:
                outputs["files"] = ArrayFileSegment(value=self._file_outputs)

            # Send final chunk event to indicate streaming is complete
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk="",
                is_final=True,
            )
            if not self.tool_call_enabled:
                # For tool calls, final events are already sent in _process_tool_outputs
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk="",
                    is_final=True,
                )
                yield StreamChunkEvent(
                    selector=[self._node_id, "generation", "content"],
                    chunk="",
                    is_final=True,
                )
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk="",
                    is_final=True,
                )

            metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
            }

            if generation_data and generation_data.trace:
                metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
                    segment.model_dump() for segment in generation_data.trace
                ]

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(

@@ -318,11 +403,7 @@ class LLMNode(Node[LLMNodeData]):
                    inputs=node_inputs,
                    process_data=process_data,
                    outputs=outputs,
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                        WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
                        WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
                    },
                    metadata=metadata,
                    llm_usage=usage,
                )
            )

@@ -444,6 +525,8 @@ class LLMNode(Node[LLMNodeData]):
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()
        think_parser = ThinkTagStreamParser()
        reasoning_chunks: list[str] = []

        # Initialize streaming metrics tracking
        start_time = request_start_time if request_start_time is not None else time.perf_counter()

@@ -472,12 +555,32 @@ class LLMNode(Node[LLMNodeData]):
                    has_content = True

                full_text_buffer.write(text_part)
                # Text output: always forward raw chunk (keep <think> tags intact)
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=text_part,
                    is_final=False,
                )

                # Generation output: split out thoughts, forward only non-thought content chunks
                for kind, segment in think_parser.process(text_part):
                    if not segment:
                        continue

                    if kind == "thought":
                        reasoning_chunks.append(segment)
                        yield ThoughtChunkEvent(
                            selector=[node_id, "generation", "thought"],
                            chunk=segment,
                            is_final=False,
                        )
                    else:
                        yield StreamChunkEvent(
                            selector=[node_id, "generation", "content"],
                            chunk=segment,
                            is_final=False,
                        )

            # Update the whole metadata
            if not model and result.model:
                model = result.model

@@ -492,16 +595,35 @@ class LLMNode(Node[LLMNodeData]):
            except OutputParserError as e:
                raise LLMNodeError(f"Failed to parse structured output: {e}")

        for kind, segment in think_parser.flush():
            if not segment:
                continue
            if kind == "thought":
                reasoning_chunks.append(segment)
                yield ThoughtChunkEvent(
                    selector=[node_id, "generation", "thought"],
                    chunk=segment,
                    is_final=False,
                )
            else:
                yield StreamChunkEvent(
                    selector=[node_id, "generation", "content"],
                    chunk=segment,
                    is_final=False,
                )

        # Extract reasoning content from <think> tags in the main text
        full_text = full_text_buffer.getvalue()

        if reasoning_format == "tagged":
            # Keep <think> tags in text for backward compatibility
            clean_text = full_text
            reasoning_content = ""
            reasoning_content = "".join(reasoning_chunks)
        else:
            # Extract clean text and reasoning from <think> tags
            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
            if reasoning_chunks and not reasoning_content:
                reasoning_content = "".join(reasoning_chunks)

        # Calculate streaming metrics
        end_time = time.perf_counter()

@@ -1266,6 +1388,635 @@ class LLMNode(Node[LLMNodeData]):
    def retry(self) -> bool:
        return self.node_data.retry_config.retry_enabled

    @property
    def tool_call_enabled(self) -> bool:
        return (
            self.node_data.tools is not None
            and len(self.node_data.tools) > 0
            and all(tool.enabled for tool in self.node_data.tools)
        )

    def _stream_llm_events(
        self,
        generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
        *,
        model_instance: ModelInstance,
    ) -> Generator[
        NodeEventBase,
        None,
        tuple[
            str,
            str,
            str,
            str,
            LLMUsage,
            str | None,
            LLMStructuredOutput | None,
            LLMGenerationData | None,
        ],
    ]:
        """
        Stream events and capture generator return value in one place.
        Uses generator delegation so _run stays concise while still emitting events.
        """
        clean_text = ""
        reasoning_content = ""
        generation_reasoning_content = ""
        generation_clean_content = ""
        usage = LLMUsage.empty_usage()
        finish_reason: str | None = None
        structured_output: LLMStructuredOutput | None = None
        generation_data: LLMGenerationData | None = None
        completed = False

        while True:
            try:
                event = next(generator)
            except StopIteration as exc:
                if isinstance(exc.value, LLMGenerationData):
                    generation_data = exc.value
                break

            if completed:
                # After completion we still drain to reach StopIteration.value
                continue

            match event:
                case StreamChunkEvent() | ThoughtChunkEvent():
                    yield event

                case ModelInvokeCompletedEvent(
                    text=text,
                    usage=usage_event,
                    finish_reason=finish_reason_event,
                    reasoning_content=reasoning_event,
                    structured_output=structured_raw,
                ):
                    clean_text = text
                    usage = usage_event
                    finish_reason = finish_reason_event
                    reasoning_content = reasoning_event or ""
                    generation_reasoning_content = reasoning_content
                    generation_clean_content = clean_text

                    if self.node_data.reasoning_format == "tagged":
                        # Keep tagged text for output; also extract reasoning for generation field
                        generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, reasoning_format="separated"
                        )
                    else:
                        clean_text, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, self.node_data.reasoning_format
                        )
                        generation_clean_content = clean_text

                    structured_output = (
                        LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
                    )

                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    completed = True

                case LLMStructuredOutput():
                    structured_output = event

                case _:
                    continue

        return (
            clean_text,
            reasoning_content,
            generation_reasoning_content,
            generation_clean_content,
            usage,
            finish_reason,
            structured_output,
            generation_data,
        )

    def _invoke_llm_with_tools(
        self,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        files: Sequence["File"],
        variable_pool: VariablePool,
        node_inputs: dict[str, Any],
        process_data: dict[str, Any],
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Invoke LLM with tools support (from Agent V2).

        Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files
        """
        # Get model features to determine strategy
        model_features = self._get_model_features(model_instance)

        # Prepare tool instances
        tool_instances = self._prepare_tool_instances(variable_pool)

        # Prepare prompt files (files that come from prompt variables, not vision files)
        prompt_files = self._extract_prompt_files(variable_pool)

        # Use factory to create appropriate strategy
        strategy = StrategyFactory.create_strategy(
            model_features=model_features,
            model_instance=model_instance,
            tools=tool_instances,
            files=prompt_files,
            max_iterations=self._node_data.max_iterations or 10,
            context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
        )

        # Run strategy
        outputs = strategy.run(
            prompt_messages=list(prompt_messages),
            model_parameters=self._node_data.model.completion_params,
            stop=list(stop or []),
            stream=True,
        )

        # Process outputs and return generation result
        result = yield from self._process_tool_outputs(outputs)
        return result

    def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
        """Get model schema to determine features."""
        try:
            model_type_instance = model_instance.model_type_instance
            model_schema = model_type_instance.get_model_schema(
                model_instance.model,
                model_instance.credentials,
            )
            return model_schema.features if model_schema and model_schema.features else []
        except Exception:
            logger.warning("Failed to get model schema, assuming no special features")
            return []

    def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
        """Prepare tool instances from configuration."""
        tool_instances = []

        if self._node_data.tools:
            for tool in self._node_data.tools:
                try:
                    # Process settings to extract the correct structure
                    processed_settings = {}
                    for key, value in tool.settings.items():
                        if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
                            # Extract the nested value if it has the ToolInput structure
                            if "type" in value["value"] and "value" in value["value"]:
                                processed_settings[key] = value["value"]
                            else:
                                processed_settings[key] = value
                        else:
                            processed_settings[key] = value

                    # Merge parameters with processed settings (similar to Agent Node logic)
                    merged_parameters = {**tool.parameters, **processed_settings}

                    # Create AgentToolEntity from ToolMetadata
                    agent_tool = AgentToolEntity(
                        provider_id=tool.provider_name,
                        provider_type=tool.type,
                        tool_name=tool.tool_name,
                        tool_parameters=merged_parameters,
                        plugin_unique_identifier=tool.plugin_unique_identifier,
                        credential_id=tool.credential_id,
                    )

                    # Get tool runtime from ToolManager
                    tool_runtime = ToolManager.get_agent_tool_runtime(
                        tenant_id=self.tenant_id,
                        app_id=self.app_id,
                        agent_tool=agent_tool,
                        invoke_from=self.invoke_from,
                        variable_pool=variable_pool,
                    )

                    # Apply custom description from extra field if available
                    if tool.extra.get("description") and tool_runtime.entity.description:
                        tool_runtime.entity.description.llm = (
                            tool.extra.get("description") or tool_runtime.entity.description.llm
                        )

                    tool_instances.append(tool_runtime)
                except Exception as e:
                    logger.warning("Failed to load tool %s: %s", tool, str(e))
                    continue

        return tool_instances

    def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]:
        """Extract files from prompt template variables."""
        from core.variables import ArrayFileVariable, FileVariable

        files: list[File] = []

        # Extract variables from prompt template
        if isinstance(self._node_data.prompt_template, list):
            for message in self._node_data.prompt_template:
                if message.text:
                    parser = VariableTemplateParser(message.text)
                    variable_selectors = parser.extract_variable_selectors()

                    for variable_selector in variable_selectors:
                        variable = variable_pool.get(variable_selector.value_selector)
                        if isinstance(variable, FileVariable) and variable.value:
                            files.append(variable.value)
                        elif isinstance(variable, ArrayFileVariable) and variable.value:
                            files.extend(variable.value)

        return files

    @staticmethod
    def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
        """Convert ToolCallResult into JSON-friendly dict."""

        def _file_to_ref(file: File) -> str | None:
            # Align with streamed tool result events which carry file IDs
            return file.id or file.related_id

        files = []
        for file in tool_call.files or []:
            ref = _file_to_ref(file)
            if ref:
                files.append(ref)

        return {
            "id": tool_call.id,
            "name": tool_call.name,
            "arguments": tool_call.arguments,
            "output": tool_call.output,
            "files": files,
            "status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status,
        }

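    # --- Illustrative sketch (not part of the commit) ----------------------------
    # Expected shape of the serialized dict for one successful tool call with a
    # single generated file. Concrete values, and the "success" status string
    # (ToolResultStatus.SUCCESS.value), are assumptions.
    #
    # {
    #     "id": "call_1",
    #     "name": "get_weather",
    #     "arguments": '{"city": "Paris"}',
    #     "output": "18 degrees, sunny",
    #     "files": ["file-uuid-1"],
    #     "status": "success",
    # }
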
    def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
        if not buffers.pending_thought:
            return
        trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought)))
        buffers.pending_thought.clear()

    def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
        if not buffers.pending_content:
            return
        trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content)))
        buffers.pending_content.clear()

    def _handle_agent_log_output(
        self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
    ) -> Generator[NodeEventBase, None, None]:
        payload = ToolLogPayload.from_log(output)
        agent_log_event = AgentLogEvent(
            message_id=output.id,
            label=output.label,
            node_execution_id=self.id,
            parent_id=output.parent_id,
            error=output.error,
            status=output.status.value,
            data=output.data,
            metadata={k.value: v for k, v in output.metadata.items()},
            node_id=self._node_id,
        )
        for log in agent_context.agent_logs:
            if log.message_id == agent_log_event.message_id:
                log.data = agent_log_event.data
                log.status = agent_log_event.status
                log.error = agent_log_event.error
                log.label = agent_log_event.label
                log.metadata = agent_log_event.metadata
                break
        else:
            agent_context.agent_logs.append(agent_log_event)

        if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
            tool_name = payload.tool_name
            tool_call_id = payload.tool_call_id
            tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else ""

            if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
                trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

            self._flush_thought_segment(buffers, trace_state)
            self._flush_content_segment(buffers, trace_state)

            tool_call_segment = LLMTraceSegment(
                type="tool_call",
                text=None,
                tool_call=ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=tool_arguments,
                ),
            )
            trace_state.trace_segments.append(tool_call_segment)
            if tool_call_id:
                trace_state.tool_trace_map[tool_call_id] = tool_call_segment

            yield ToolCallChunkEvent(
                selector=[self._node_id, "generation", "tool_calls"],
                chunk=tool_arguments,
                tool_call=ToolCall(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=tool_arguments,
                ),
                is_final=False,
            )

        if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
            tool_name = payload.tool_name
            tool_output = payload.tool_output
            tool_call_id = payload.tool_call_id
            tool_files = payload.files if isinstance(payload.files, list) else []
            tool_error = payload.tool_error

            if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
                trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

            self._flush_thought_segment(buffers, trace_state)
            self._flush_content_segment(buffers, trace_state)

            if output.status == AgentLog.LogStatus.ERROR:
                tool_error = output.error or payload.tool_error
                if not tool_error and payload.meta:
                    tool_error = payload.meta.get("error")
            else:
                if payload.meta:
                    meta_error = payload.meta.get("error")
                    if meta_error:
                        tool_error = meta_error

            existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id)
            tool_call_segment = existing_tool_segment or LLMTraceSegment(
                type="tool_call",
                text=None,
                tool_call=ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=None,
                ),
            )
            if existing_tool_segment is None:
                trace_state.trace_segments.append(tool_call_segment)
                if tool_call_id:
                    trace_state.tool_trace_map[tool_call_id] = tool_call_segment

            if tool_call_segment.tool_call is None:
                tool_call_segment.tool_call = ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=None,
                )
            tool_call_segment.tool_call.output = (
                str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None
            )
            tool_call_segment.tool_call.files = []
            tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS

            result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None

            yield ToolResultChunkEvent(
                selector=[self._node_id, "generation", "tool_results"],
                chunk=result_output or "",
                tool_result=ToolResult(
                    id=tool_call_id,
                    name=tool_name,
                    output=result_output,
                    files=tool_files,
                    status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                ),
                is_final=False,
            )

            if buffers.current_turn_reasoning:
                buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
                buffers.current_turn_reasoning.clear()

    def _handle_llm_chunk_output(
        self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
    ) -> Generator[NodeEventBase, None, None]:
        message = output.delta.message

        if message and message.content:
            chunk_text = message.content
            if isinstance(chunk_text, list):
                chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
            else:
                chunk_text = str(chunk_text)

            for kind, segment in buffers.think_parser.process(chunk_text):
                if not segment:
                    continue

                if kind == "thought":
                    self._flush_content_segment(buffers, trace_state)
                    buffers.current_turn_reasoning.append(segment)
                    buffers.pending_thought.append(segment)
                    yield ThoughtChunkEvent(
                        selector=[self._node_id, "generation", "thought"],
                        chunk=segment,
                        is_final=False,
                    )
                else:
                    self._flush_thought_segment(buffers, trace_state)
                    aggregate.text += segment
                    buffers.pending_content.append(segment)
                    yield StreamChunkEvent(
                        selector=[self._node_id, "text"],
                        chunk=segment,
                        is_final=False,
                    )
                    yield StreamChunkEvent(
                        selector=[self._node_id, "generation", "content"],
                        chunk=segment,
                        is_final=False,
                    )

        if output.delta.usage:
            self._accumulate_usage(aggregate.usage, output.delta.usage)

        if output.delta.finish_reason:
            aggregate.finish_reason = output.delta.finish_reason

    def _flush_remaining_stream(
        self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
    ) -> Generator[NodeEventBase, None, None]:
        for kind, segment in buffers.think_parser.flush():
            if not segment:
                continue
            if kind == "thought":
                self._flush_content_segment(buffers, trace_state)
                buffers.current_turn_reasoning.append(segment)
                buffers.pending_thought.append(segment)
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk=segment,
                    is_final=False,
                )
            else:
                self._flush_thought_segment(buffers, trace_state)
                aggregate.text += segment
                buffers.pending_content.append(segment)
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=segment,
                    is_final=False,
                )
                yield StreamChunkEvent(
                    selector=[self._node_id, "generation", "content"],
                    chunk=segment,
                    is_final=False,
                )

        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))

        self._flush_thought_segment(buffers, trace_state)
        self._flush_content_segment(buffers, trace_state)

    def _close_streams(self) -> Generator[NodeEventBase, None, None]:
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )
        yield StreamChunkEvent(
            selector=[self._node_id, "generation", "content"],
            chunk="",
            is_final=True,
        )
        yield ThoughtChunkEvent(
            selector=[self._node_id, "generation", "thought"],
            chunk="",
            is_final=True,
        )
        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk="",
            tool_call=ToolCall(
                id="",
                name="",
                arguments="",
            ),
            is_final=True,
        )
        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk="",
            tool_result=ToolResult(
                id="",
                name="",
                output="",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        )

    def _build_generation_data(
        self,
        trace_state: TraceState,
        agent_context: AgentContext,
        aggregate: AggregatedResult,
        buffers: StreamBuffers,
    ) -> LLMGenerationData:
        sequence: list[dict[str, Any]] = []
        reasoning_index = 0
        content_position = 0
        tool_call_seen_index: dict[str, int] = {}
        for trace_segment in trace_state.trace_segments:
            if trace_segment.type == "thought":
                sequence.append({"type": "reasoning", "index": reasoning_index})
                reasoning_index += 1
            elif trace_segment.type == "content":
                segment_text = trace_segment.text or ""
                start = content_position
                end = start + len(segment_text)
                sequence.append({"type": "content", "start": start, "end": end})
                content_position = end
            elif trace_segment.type == "tool_call":
                tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
                if tool_id not in tool_call_seen_index:
                    tool_call_seen_index[tool_id] = len(tool_call_seen_index)
                sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})

        tool_calls_for_generation: list[ToolCallResult] = []
        for log in agent_context.agent_logs:
            payload = ToolLogPayload.from_mapping(log.data or {})
            tool_call_id = payload.tool_call_id
            if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
                continue

            tool_args = payload.tool_args
            log_error = payload.tool_error
            log_output = payload.tool_output
            result_text = log_output or log_error or ""
            status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
            tool_calls_for_generation.append(
                ToolCallResult(
                    id=tool_call_id,
                    name=payload.tool_name,
                    arguments=json.dumps(tool_args) if tool_args else "",
                    output=result_text,
                    status=status,
                )
            )

        tool_calls_for_generation.sort(
            key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
        )

        return LLMGenerationData(
            text=aggregate.text,
            reasoning_contents=buffers.reasoning_per_turn,
            tool_calls=tool_calls_for_generation,
            sequence=sequence,
            usage=aggregate.usage,
            finish_reason=aggregate.finish_reason,
            files=aggregate.files,
            trace=trace_state.trace_segments,
        )

    def _process_tool_outputs(
        self,
        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Process strategy outputs and convert to node events."""
        state = ToolOutputState()

        try:
            for output in outputs:
                if isinstance(output, AgentLog):
                    yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
                else:
                    yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
        except StopIteration as exception:
            if isinstance(getattr(exception, "value", None), AgentResult):
                state.agent.agent_result = exception.value

        if state.agent.agent_result:
            state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
            state.aggregate.files = state.agent.agent_result.files
            if state.agent.agent_result.usage:
                state.aggregate.usage = state.agent.agent_result.usage
            if state.agent.agent_result.finish_reason:
                state.aggregate.finish_reason = state.agent.agent_result.finish_reason

        yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
        yield from self._close_streams()

        return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)

    def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
        """Accumulate LLM usage statistics."""
        total_usage.prompt_tokens += delta_usage.prompt_tokens
        total_usage.completion_tokens += delta_usage.completion_tokens
        total_usage.total_tokens += delta_usage.total_tokens
        total_usage.prompt_price += delta_usage.prompt_price
        total_usage.completion_price += delta_usage.completion_price
        total_usage.total_price += delta_usage.total_price

def _combine_message_content_with_role(
    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

@@ -89,6 +89,7 @@ message_detail_fields = {
    "status": fields.String,
    "error": fields.String,
    "parent_message_id": fields.String,
    "generation_detail": fields.Raw,
}

feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}

@@ -68,6 +68,7 @@ message_fields = {
    "message_files": fields.List(fields.Nested(message_file_fields)),
    "status": fields.String,
    "error": fields.String,
    "generation_detail": fields.Raw,
}

message_infinite_scroll_pagination_fields = {

@@ -81,6 +81,7 @@ workflow_run_detail_fields = {
    "inputs": fields.Raw(attribute="inputs_dict"),
    "status": fields.String,
    "outputs": fields.Raw(attribute="outputs_dict"),
    "outputs_as_generation": fields.Boolean,
    "error": fields.String,
    "elapsed_time": fields.Float,
    "total_tokens": fields.Integer,

@@ -129,6 +130,7 @@ workflow_run_node_execution_fields = {
    "inputs_truncated": fields.Boolean,
    "outputs_truncated": fields.Boolean,
    "process_data_truncated": fields.Boolean,
    "generation_detail": fields.Raw,
}

workflow_run_node_execution_list_fields = {

@@ -0,0 +1,46 @@
"""add llm generation detail table.

Revision ID: 85c8b4a64f53
Revises: 7bb281b7a422
Create Date: 2025-12-10 16:17:46.597669

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '85c8b4a64f53'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('llm_generation_details',
    sa.Column('id', models.types.StringUUID(), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('app_id', models.types.StringUUID(), nullable=False),
    sa.Column('message_id', models.types.StringUUID(), nullable=True),
    sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
    sa.Column('node_id', sa.String(length=255), nullable=True),
    sa.Column('reasoning_content', models.types.LongText(), nullable=True),
    sa.Column('tool_calls', models.types.LongText(), nullable=True),
    sa.Column('sequence', models.types.LongText(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.CheckConstraint('(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)', name=op.f('llm_generation_details_ck_llm_generation_detail_assoc_mode_check')),
    sa.PrimaryKeyConstraint('id', name='llm_generation_detail_pkey'),
    sa.UniqueConstraint('message_id', name=op.f('llm_generation_details_message_id_key'))
    )
    with op.batch_alter_table('llm_generation_details', schema=None) as batch_op:
        batch_op.create_index('idx_llm_generation_detail_message', ['message_id'], unique=False)
        batch_op.create_index('idx_llm_generation_detail_workflow', ['workflow_run_id', 'node_id'], unique=False)


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('llm_generation_details')
    # ### end Alembic commands ###
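For orientation (not part of the migration): the CHECK constraint enforces exactly one association mode per row. A sketch of the two column combinations it accepts, with illustrative ids.

# Allowed: chat-message mode -- message_id set, workflow columns NULL.
chat_message_row = {"message_id": "example-message-id", "workflow_run_id": None, "node_id": None}
# Allowed: workflow mode -- workflow_run_id and node_id both set, message_id NULL.
workflow_node_row = {"message_id": None, "workflow_run_id": "example-run-id", "node_id": "llm_node_1"}
# Any other combination (both modes set, neither set, or a run id without a node id)
# violates the assoc-mode CHECK constraint and is rejected by the database.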
@@ -49,6 +49,7 @@ from .model import (
    EndUser,
    IconType,
    InstalledApp,
    LLMGenerationDetail,
    Message,
    MessageAgentThought,
    MessageAnnotation,
@@ -155,6 +156,7 @@ __all__ = [
    "IconType",
    "InstalledApp",
    "InvitationCode",
    "LLMGenerationDetail",
    "LoadBalancingModelConfig",
    "Message",
    "MessageAgentThought",
@@ -31,6 +31,8 @@ from .provider_ids import GenericProviderID
from .types import LongText, StringUUID

if TYPE_CHECKING:
    from core.app.entities.llm_generation_entities import LLMGenerationDetailData

    from .workflow import Workflow

@@ -1201,6 +1203,18 @@ class Message(Base):
            .all()
        )

    # FIXME (Novice) -- It's easy to cause an N+1 query problem here.
    @property
    def generation_detail(self) -> dict[str, Any] | None:
        """
        Get the LLM generation detail for this message.
        Returns the detail as a dictionary, or None if not found.
        """
        detail = db.session.query(LLMGenerationDetail).filter_by(message_id=self.id).first()
        if detail:
            return detail.to_dict()
        return None

    @property
    def retriever_resources(self) -> Any:
        return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@@ -2091,3 +2105,87 @@ class TenantCreditPool(Base):

    def has_sufficient_credits(self, required_credits: int) -> bool:
        return self.remaining_credits >= required_credits


class LLMGenerationDetail(Base):
    """
    Store LLM generation details including reasoning process and tool calls.

    Association (choose one):
    - For apps with Message: use message_id (one-to-one)
    - For Workflow: use workflow_run_id + node_id (one run may have multiple LLM nodes)
    """

    __tablename__ = "llm_generation_details"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"),
        sa.Index("idx_llm_generation_detail_message", "message_id"),
        sa.Index("idx_llm_generation_detail_workflow", "workflow_run_id", "node_id"),
        sa.CheckConstraint(
            "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL)"
            " OR "
            "(message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)",
            name="ck_llm_generation_detail_assoc_mode",
        ),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # Association fields (choose one)
    message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, unique=True)
    workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    node_id: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # Core data as JSON strings
    reasoning_content: Mapped[str | None] = mapped_column(LongText)
    tool_calls: Mapped[str | None] = mapped_column(LongText)
    sequence: Mapped[str | None] = mapped_column(LongText)

    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

    def to_domain_model(self) -> "LLMGenerationDetailData":
        """Convert to Pydantic domain model with proper validation."""
        from core.app.entities.llm_generation_entities import LLMGenerationDetailData

        return LLMGenerationDetailData(
            reasoning_content=json.loads(self.reasoning_content) if self.reasoning_content else [],
            tool_calls=json.loads(self.tool_calls) if self.tool_calls else [],
            sequence=json.loads(self.sequence) if self.sequence else [],
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for API response."""
        return self.to_domain_model().to_response_dict()

    @classmethod
    def from_domain_model(
        cls,
        data: "LLMGenerationDetailData",
        *,
        tenant_id: str,
        app_id: str,
        message_id: str | None = None,
        workflow_run_id: str | None = None,
        node_id: str | None = None,
    ) -> "LLMGenerationDetail":
        """Create from Pydantic domain model."""
        # Enforce association mode at object creation time as well.
        message_mode = message_id is not None
        workflow_mode = workflow_run_id is not None or node_id is not None
        if message_mode and workflow_mode:
            raise ValueError("LLMGenerationDetail cannot set both message_id and workflow_run_id/node_id.")
        if not message_mode and not (workflow_run_id and node_id):
            raise ValueError("LLMGenerationDetail requires either message_id or workflow_run_id+node_id.")

        return cls(
            tenant_id=tenant_id,
            app_id=app_id,
            message_id=message_id,
            workflow_run_id=workflow_run_id,
            node_id=node_id,
            reasoning_content=json.dumps(data.reasoning_content) if data.reasoning_content else None,
            tool_calls=json.dumps([tc.model_dump() for tc in data.tool_calls]) if data.tool_calls else None,
            sequence=json.dumps([seg.model_dump() for seg in data.sequence]) if data.sequence else None,
        )
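A minimal round-trip sketch (not part of the diff), assuming LLMGenerationDetailData accepts the reasoning_content / tool_calls / sequence fields used in to_domain_model() above; the ids are placeholders.

from core.app.entities.llm_generation_entities import LLMGenerationDetailData
from models import LLMGenerationDetail

data = LLMGenerationDetailData(
    reasoning_content=["Plan the answer before calling any tool."],
    tool_calls=[],
    sequence=[],
)

detail = LLMGenerationDetail.from_domain_model(
    data,
    tenant_id="example-tenant-id",
    app_id="example-app-id",
    message_id="example-message-id",  # message mode: workflow_run_id / node_id stay None
)

# to_dict() goes back through the validated domain model, so API callers never
# see the raw JSON strings stored in the LongText columns.
print(detail.to_dict())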
@@ -57,6 +57,37 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)


def is_generation_outputs(outputs: Mapping[str, Any]) -> bool:
    if not outputs:
        return False

    allowed_sequence_types = {"reasoning", "content", "tool_call"}

    def valid_sequence_item(item: Mapping[str, Any]) -> bool:
        return isinstance(item, Mapping) and item.get("type") in allowed_sequence_types

    def valid_value(value: Any) -> bool:
        if not isinstance(value, Mapping):
            return False

        content = value.get("content")
        reasoning_content = value.get("reasoning_content")
        tool_calls = value.get("tool_calls")
        sequence = value.get("sequence")

        return (
            isinstance(content, str)
            and isinstance(reasoning_content, list)
            and all(isinstance(item, str) for item in reasoning_content)
            and isinstance(tool_calls, list)
            and all(isinstance(item, Mapping) for item in tool_calls)
            and isinstance(sequence, list)
            and all(valid_sequence_item(item) for item in sequence)
        )

    return all(valid_value(value) for value in outputs.values())


class WorkflowType(StrEnum):
    """
    Workflow Type Enum
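A small sketch (not part of the diff) of outputs that is_generation_outputs accepts and rejects; the import path assumes the function lives in models/workflow.py alongside WorkflowType, and the payload contents are illustrative.

from models.workflow import is_generation_outputs  # assumed module path

generation_like = {
    "answer": {
        "content": "It is sunny in Berlin.",
        "reasoning_content": ["Check the weather tool result before answering."],
        "tool_calls": [{"name": "get_weather", "arguments": {"city": "Berlin"}}],
        "sequence": [{"type": "reasoning"}, {"type": "tool_call"}, {"type": "content"}],
    }
}
plain_outputs = {"answer": "It is sunny in Berlin."}  # bare string value, not a generation mapping

assert is_generation_outputs(generation_like) is True
assert is_generation_outputs(plain_outputs) is False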
@@ -664,6 +695,10 @@ class WorkflowRun(Base):
    def workflow(self):
        return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()

    @property
    def outputs_as_generation(self):
        return is_generation_outputs(self.outputs_dict)

    def to_dict(self):
        return {
            "id": self.id,
@@ -677,6 +712,7 @@ class WorkflowRun(Base):
            "inputs": self.inputs_dict,
            "status": self.status,
            "outputs": self.outputs_dict,
            "outputs_as_generation": self.outputs_as_generation,
            "error": self.error,
            "elapsed_time": self.elapsed_time,
            "total_tokens": self.total_tokens,
@@ -0,0 +1,37 @@
"""
LLM Generation Detail Service.

Provides methods to query and attach generation details to workflow node executions
and messages, avoiding N+1 query problems.
"""

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.entities.llm_generation_entities import LLMGenerationDetailData
from models import LLMGenerationDetail


class LLMGenerationService:
    """Service for handling LLM generation details."""

    def __init__(self, session: Session):
        self._session = session

    def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None:
        """Query generation detail for a specific message."""
        stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id)
        detail = self._session.scalars(stmt).first()
        return detail.to_domain_model() if detail else None

    def get_generation_details_for_messages(
        self,
        message_ids: list[str],
    ) -> dict[str, LLMGenerationDetailData]:
        """Batch query generation details for multiple messages."""
        if not message_ids:
            return {}

        stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids))
        details = self._session.scalars(stmt).all()
        return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id}
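A usage sketch (not part of the diff) showing how the batch method can replace the per-message lookup flagged in the Message.generation_detail FIXME. The service module path is an assumption; to_response_dict() is the conversion the model's to_dict() relies on in this commit, and everything else is illustrative.

from sqlalchemy.orm import Session

from services.llm_generation_service import LLMGenerationService  # assumed module path


def attach_generation_details(session: Session, messages: list) -> dict[str, dict]:
    """Return {message_id: generation_detail_dict} for one page of messages in a single query."""
    service = LLMGenerationService(session)
    details = service.get_generation_details_for_messages([message.id for message in messages])
    # One IN(...) query for the whole page instead of one query per message.
    return {message_id: data.to_response_dict() for message_id, data in details.items()}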
@@ -0,0 +1,4 @@
"""
Mark agent test modules as a package to avoid import name collisions.
"""
@ -0,0 +1,324 @@
|
|||
"""Tests for AgentPattern base class."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.agent.patterns.base import AgentPattern
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class ConcreteAgentPattern(AgentPattern):
|
||||
"""Concrete implementation of AgentPattern for testing."""
|
||||
|
||||
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
|
||||
"""Minimal implementation for testing."""
|
||||
yield from []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_pattern(mock_model_instance, mock_context):
|
||||
"""Create a concrete agent pattern for testing."""
|
||||
return ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
|
||||
class TestAccumulateUsage:
|
||||
"""Tests for _accumulate_usage method."""
|
||||
|
||||
def test_accumulate_usage_to_empty_dict(self, agent_pattern):
|
||||
"""Test accumulating usage to an empty dict creates a copy."""
|
||||
total_usage: dict = {"usage": None}
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"] is not None
|
||||
assert total_usage["usage"].total_tokens == 150
|
||||
assert total_usage["usage"].prompt_tokens == 100
|
||||
assert total_usage["usage"].completion_tokens == 50
|
||||
# Verify it's a copy, not a reference
|
||||
assert total_usage["usage"] is not delta_usage
|
||||
|
||||
def test_accumulate_usage_adds_to_existing(self, agent_pattern):
|
||||
"""Test accumulating usage adds to existing values."""
|
||||
initial_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
total_usage: dict = {"usage": initial_usage}
|
||||
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"].total_tokens == 450 # 150 + 300
|
||||
assert total_usage["usage"].prompt_tokens == 300 # 100 + 200
|
||||
assert total_usage["usage"].completion_tokens == 150 # 50 + 100
|
||||
|
||||
def test_accumulate_usage_multiple_rounds(self, agent_pattern):
|
||||
"""Test accumulating usage across multiple rounds."""
|
||||
total_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1: 100 tokens
|
||||
round1_usage = LLMUsage(
|
||||
prompt_tokens=70,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.07"),
|
||||
completion_tokens=30,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.06"),
|
||||
total_tokens=100,
|
||||
total_price=Decimal("0.13"),
|
||||
currency="USD",
|
||||
latency=0.3,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round1_usage)
|
||||
assert total_usage["usage"].total_tokens == 100
|
||||
|
||||
# Round 2: 150 tokens
|
||||
round2_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.4,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round2_usage)
|
||||
assert total_usage["usage"].total_tokens == 250 # 100 + 150
|
||||
|
||||
# Round 3: 200 tokens
|
||||
round3_usage = LLMUsage(
|
||||
prompt_tokens=130,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.13"),
|
||||
completion_tokens=70,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.14"),
|
||||
total_tokens=200,
|
||||
total_price=Decimal("0.27"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round3_usage)
|
||||
assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200
|
||||
|
||||
|
||||
class TestCreateLog:
|
||||
"""Tests for _create_log method."""
|
||||
|
||||
def test_create_log_with_label_and_status(self, agent_pattern):
|
||||
"""Test creating a log with label and status."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
|
||||
assert log.label == "ROUND 1"
|
||||
assert log.log_type == AgentLog.LogType.ROUND
|
||||
assert log.status == AgentLog.LogStatus.START
|
||||
assert log.data == {"key": "value"}
|
||||
assert log.parent_id is None
|
||||
|
||||
def test_create_log_with_parent_id(self, agent_pattern):
|
||||
"""Test creating a log with parent_id."""
|
||||
parent_log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
child_log = agent_pattern._create_log(
|
||||
label="CALL tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=parent_log.id,
|
||||
)
|
||||
|
||||
assert child_log.parent_id == parent_log.id
|
||||
assert child_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
|
||||
|
||||
class TestFinishLog:
|
||||
"""Tests for _finish_log method."""
|
||||
|
||||
def test_finish_log_updates_status(self, agent_pattern):
|
||||
"""Test that finish_log updates status to SUCCESS."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, data={"result": "done"})
|
||||
|
||||
assert finished_log.status == AgentLog.LogStatus.SUCCESS
|
||||
assert finished_log.data == {"result": "done"}
|
||||
|
||||
def test_finish_log_adds_usage_metadata(self, agent_pattern):
|
||||
"""Test that finish_log adds usage to metadata."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, usage=usage)
|
||||
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD"
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage
|
||||
|
||||
|
||||
class TestFindToolByName:
|
||||
"""Tests for _find_tool_by_name method."""
|
||||
|
||||
def test_find_existing_tool(self, mock_model_instance, mock_context):
|
||||
"""Test finding an existing tool by name."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("test_tool")
|
||||
assert found_tool == mock_tool
|
||||
|
||||
def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
|
||||
"""Test that finding a nonexistent tool returns None."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("nonexistent_tool")
|
||||
assert found_tool is None
|
||||
|
||||
|
||||
class TestMaxIterationsCapping:
|
||||
"""Tests for max_iterations capping."""
|
||||
|
||||
def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is capped at 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=150,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 99
|
||||
|
||||
def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is not capped when under 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=50,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 50
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""Tests for FunctionCallStrategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.to_prompt_message_tool.return_value = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string", "description": "A parameter"}},
|
||||
"required": ["param1"],
|
||||
},
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
class TestFunctionCallStrategyInit:
|
||||
"""Tests for FunctionCallStrategy initialization."""
|
||||
|
||||
def test_initialization(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test basic initialization."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
assert strategy.model_instance == mock_model_instance
|
||||
assert strategy.context == mock_context
|
||||
assert strategy.max_iterations == 10
|
||||
assert len(strategy.tools) == 1
|
||||
|
||||
def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test initialization with tool_invoke_hook."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
|
||||
|
||||
class TestConvertToolsToPromptFormat:
|
||||
"""Tests for _convert_tools_to_prompt_format method."""
|
||||
|
||||
def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that _convert_tools_to_prompt_format returns PromptMessageTool list."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], PromptMessageTool)
|
||||
assert tools[0].name == "test_tool"
|
||||
|
||||
def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
|
||||
"""Test that _convert_tools_to_prompt_format returns empty list when no tools."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert tools == []
|
||||
|
||||
|
||||
class TestAgentLogGeneration:
|
||||
"""Tests for AgentLog generation during run."""
|
||||
|
||||
def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that round logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Create a round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"inputs": {"query": "test"}},
|
||||
)
|
||||
|
||||
assert round_log.label == "ROUND 1"
|
||||
assert round_log.log_type == AgentLog.LogType.ROUND
|
||||
assert round_log.status == AgentLog.LogStatus.START
|
||||
assert "inputs" in round_log.data
|
||||
|
||||
def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool call logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Create a parent round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
# Create a tool call log
|
||||
tool_log = strategy._create_log(
|
||||
label="CALL test_tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
|
||||
parent_id=round_log.id,
|
||||
)
|
||||
|
||||
assert tool_log.label == "CALL test_tool"
|
||||
assert tool_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
assert tool_log.parent_id == round_log.id
|
||||
assert tool_log.data["tool_name"] == "test_tool"
|
||||
|
||||
|
||||
class TestToolInvocation:
|
||||
"""Tests for tool invocation."""
|
||||
|
||||
def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool invocation uses hook when provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_hook = MagicMock()
|
||||
mock_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
|
||||
)
|
||||
mock_hook.return_value = ("Tool result", ["file-1"], mock_meta)
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
|
||||
|
||||
mock_hook.assert_called_once()
|
||||
assert result == "Tool result"
|
||||
assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list
|
||||
assert meta == mock_meta
|
||||
|
||||
def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool_invoke_hook is None when not provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=None,
|
||||
)
|
||||
|
||||
# Verify that tool_invoke_hook is None
|
||||
assert strategy.tool_invoke_hook is None
|
||||
|
||||
|
||||
class TestUsageTracking:
|
||||
"""Tests for usage tracking across rounds."""
|
||||
|
||||
def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
|
||||
"""Test that round usage is tracked separately from total."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Simulate two rounds of usage
|
||||
total_usage: dict = {"usage": None}
|
||||
round1_usage: dict = {"usage": None}
|
||||
round2_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1
|
||||
usage1 = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round1_usage, usage1)
|
||||
strategy._accumulate_usage(total_usage, usage1)
|
||||
|
||||
# Round 2
|
||||
usage2 = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round2_usage, usage2)
|
||||
strategy._accumulate_usage(total_usage, usage2)
|
||||
|
||||
# Verify round usage is separate
|
||||
assert round1_usage["usage"].total_tokens == 150
|
||||
assert round2_usage["usage"].total_tokens == 300
|
||||
# Verify total is accumulated
|
||||
assert total_usage["usage"].total_tokens == 450
|
||||
|
||||
|
||||
class TestPromptMessageHandling:
|
||||
"""Tests for prompt message handling."""
|
||||
|
||||
def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that messages include system and user prompts."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
# Just verify the messages can be processed
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], SystemPromptMessage)
|
||||
assert isinstance(messages[1], UserPromptMessage)
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that assistant messages can contain tool calls."""
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="call_123",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments='{"param1": "value1"}',
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(
|
||||
content="I'll help you with that.",
|
||||
tool_calls=[tool_call],
|
||||
)
|
||||
|
||||
assert len(assistant_message.tool_calls) == 1
|
||||
assert assistant_message.tool_calls[0].function.name == "test_tool"
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
"""Tests for ReActStrategy."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import ExecutionContext
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.entity.identity.provider = "test_provider"
|
||||
|
||||
# Use real PromptMessageTool for proper serialization
|
||||
prompt_tool = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
tool.to_prompt_message_tool.return_value = prompt_tool
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
class TestReActStrategyInit:
|
||||
"""Tests for ReActStrategy initialization."""
|
||||
|
||||
def test_init_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that instruction is stored correctly."""
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that empty instruction is handled correctly."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
assert strategy.instruction == ""
|
||||
|
||||
|
||||
class TestBuildPromptWithReactFormat:
|
||||
"""Tests for _build_prompt_with_react_format method."""
|
||||
|
||||
def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tools}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# The tools placeholder should be replaced with JSON
|
||||
assert "{{tools}}" not in result[0].content
|
||||
assert "test_tool" in result[0].content
|
||||
|
||||
def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tool_names}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "Valid actions: {{tool_names}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
assert "{{tool_names}}" not in result[0].content
|
||||
assert '"test_tool"' in result[0].content
|
||||
|
||||
def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
|
||||
"""Test that {{instruction}} placeholder is replaced."""
|
||||
instruction = "You are a helpful coding assistant."
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
system_content = "{{instruction}}\n\nYou have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True, instruction)
|
||||
|
||||
assert "{{instruction}}" not in result[0].content
|
||||
assert instruction in result[0].content
|
||||
|
||||
def test_no_tools_available_message(self, mock_model_instance, mock_context):
|
||||
"""Test that 'No tools available' is shown when include_tools is False."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], False)
|
||||
|
||||
assert "No tools available" in result[0].content
|
||||
|
||||
def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
|
||||
"""Test that agent scratchpad is appended as AssistantPromptMessage."""
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
scratchpad = [
|
||||
AgentScratchpadUnit(
|
||||
thought="I need to search for information",
|
||||
action_str='{"action": "search", "action_input": "query"}',
|
||||
observation="Search results here",
|
||||
)
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, scratchpad, True)
|
||||
|
||||
# The last message should be an AssistantPromptMessage with scratchpad content
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[-1], AssistantPromptMessage)
|
||||
assert "I need to search for information" in result[-1].content
|
||||
assert "Search results here" in result[-1].content
|
||||
|
||||
def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
|
||||
"""Test that empty scratchpad doesn't add extra message."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Should only have the original 2 messages
|
||||
assert len(result) == 2
|
||||
|
||||
def test_original_messages_not_modified(self, mock_model_instance, mock_context):
|
||||
"""Test that original messages list is not modified."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
original_content = "Original system prompt {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=original_content),
|
||||
]
|
||||
|
||||
strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Original message should not be modified
|
||||
assert messages[0].content == original_content
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
"""Tests for StrategyFactory."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyFactory:
|
||||
"""Tests for StrategyFactory.create_strategy method."""
|
||||
|
||||
def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports TOOL_CALL."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL."""
|
||||
model_features = [ModelFeature.MULTI_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL."""
|
||||
model_features = [ModelFeature.STREAM_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model doesn't support tool calling."""
|
||||
model_features = [ModelFeature.VISION] # Only vision, no tool calling
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model has no features."""
|
||||
model_features: list[ModelFeature] = []
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
|
||||
"""Test explicit FUNCTION_CALLING strategy selection with model support."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
|
||||
self, mock_model_instance, mock_context
|
||||
):
|
||||
"""Test that explicit FUNCTION_CALLING falls back to ReAct when not supported."""
|
||||
model_features: list[ModelFeature] = [] # No tool calling support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
# Should fall back to ReAct since FC is not supported
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test explicit CHAIN_OF_THOUGHT strategy selection."""
|
||||
model_features = [ModelFeature.TOOL_CALL] # Even with tool call support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy receives instruction parameter."""
|
||||
model_features: list[ModelFeature] = []
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
max_iterations = 5
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
assert strategy.max_iterations == max_iterations
|
||||
|
||||
def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that tool_invoke_hook is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
|
|
@ -0,0 +1,388 @@
|
|||
"""Tests for AgentAppRunner."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
"""Tests for _organize_prompt_messages method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
# We'll patch the class to avoid complex initialization
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
|
||||
# Set up required attributes
|
||||
runner.config = MagicMock(spec=AgentEntity)
|
||||
runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner.config.prompt = None
|
||||
|
||||
runner.app_config = MagicMock()
|
||||
runner.app_config.prompt_template = MagicMock()
|
||||
runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
runner.query = "Hello"
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.model_config = MagicMock()
|
||||
runner.memory = None
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.file_upload_config = None
|
||||
|
||||
return runner
|
||||
|
||||
def test_function_calling_uses_simple_prompt(self, mock_runner):
|
||||
"""Test that function calling strategy uses simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
|
||||
"""Test that chain of thought strategy uses agent prompt template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = AgentPromptEntity(
|
||||
first_prompt="ReAct prompt template with {{tools}}",
|
||||
next_iteration="Continue...",
|
||||
)
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="ReAct prompt template with {{tools}}")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with agent prompt
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "ReAct prompt template with {{tools}}"
|
||||
|
||||
def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
|
||||
"""Test that chain of thought without prompt falls back to simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = None
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
"""Tests for _init_system_message method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
return runner
|
||||
|
||||
def test_empty_messages_with_template(self, mock_runner):
|
||||
"""Test that system message is created when messages are empty."""
|
||||
result = mock_runner._init_system_message("System template", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert result[0].content == "System template"
|
||||
|
||||
    def test_empty_messages_without_template(self, mock_runner):
        """Test that empty list is returned when no template and no messages."""
        result = mock_runner._init_system_message("", [])

        assert result == []

    def test_existing_system_message_not_duplicated(self, mock_runner):
        """Test that system message is not duplicated if already present."""
        existing_messages = [
            SystemPromptMessage(content="Existing system"),
            UserPromptMessage(content="User message"),
        ]

        result = mock_runner._init_system_message("New template", existing_messages)

        # Should not insert new system message
        assert len(result) == 2
        assert result[0].content == "Existing system"

    def test_system_message_inserted_when_missing(self, mock_runner):
        """Test that system message is inserted when first message is not system."""
        existing_messages = [
            UserPromptMessage(content="User message"),
        ]

        result = mock_runner._init_system_message("System template", existing_messages)

        assert len(result) == 2
        assert isinstance(result[0], SystemPromptMessage)
        assert result[0].content == "System template"


class TestClearUserPromptImageMessages:
    """Tests for _clear_user_prompt_image_messages method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            return runner

    def test_text_content_unchanged(self, mock_runner):
        """Test that text content is unchanged."""
        messages = [
            UserPromptMessage(content="Plain text message"),
        ]

        result = mock_runner._clear_user_prompt_image_messages(messages)

        assert len(result) == 1
        assert result[0].content == "Plain text message"

    def test_original_messages_not_modified(self, mock_runner):
        """Test that original messages are not modified (deep copy)."""
        from core.model_runtime.entities.message_entities import (
            ImagePromptMessageContent,
            TextPromptMessageContent,
        )

        messages = [
            UserPromptMessage(
                content=[
                    TextPromptMessageContent(data="Text part"),
                    ImagePromptMessageContent(
                        data="http://example.com/image.jpg",
                        format="url",
                        mime_type="image/jpeg",
                    ),
                ]
            ),
        ]

        result = mock_runner._clear_user_prompt_image_messages(messages)

        # Original should still have list content
        assert isinstance(messages[0].content, list)
        # Result should have string content
        assert isinstance(result[0].content, str)


class TestToolInvokeHook:
    """Tests for _create_tool_invoke_hook method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)

            runner.user_id = "test-user"
            runner.tenant_id = "test-tenant"
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.trace_manager = None
            runner.application_generate_entity.invoke_from = "api"
            runner.application_generate_entity.app_config = MagicMock()
            runner.application_generate_entity.app_config.app_id = "test-app"
            runner.agent_callback = MagicMock()
            runner.conversation = MagicMock()
            runner.conversation.id = "test-conversation"
            runner.queue_manager = MagicMock()
            runner._current_message_file_ids = []

            return runner

    def test_hook_calls_agent_invoke(self, mock_runner):
        """Test that the hook calls ToolEngine.agent_invoke."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"

        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={
                "tool_provider_type": "test_provider",
                "tool_provider": "test_id",
            },
        )

        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)

            hook = mock_runner._create_tool_invoke_hook(mock_message)
            result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")

            # Verify ToolEngine.agent_invoke was called
            mock_engine.agent_invoke.assert_called_once()

            # Verify return values
            assert result_content == "Tool result"
            assert result_files == ["file-1", "file-2"]
            assert result_meta == mock_tool_meta

    def test_hook_publishes_file_events(self, mock_runner):
        """Test that the hook publishes QueueMessageFileEvent for files."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"

        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={},
        )

        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)

            hook = mock_runner._create_tool_invoke_hook(mock_message)
            hook(mock_tool, {}, "test_tool")

            # Verify file events were published
            assert mock_runner.queue_manager.publish.call_count == 2
            assert mock_runner._current_message_file_ids == ["file-1", "file-2"]


class TestAgentLogProcessing:
    """Tests for AgentLog processing in run method."""

    def test_agent_log_status_enum(self):
        """Test AgentLog status enum values."""
        assert AgentLog.LogStatus.START == "start"
        assert AgentLog.LogStatus.SUCCESS == "success"
        assert AgentLog.LogStatus.ERROR == "error"

    def test_agent_log_metadata_enum(self):
        """Test AgentLog metadata enum values."""
        assert AgentLog.LogMetadata.STARTED_AT == "started_at"
        assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
        assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
        assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
        assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
        assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"

    def test_agent_result_structure(self):
        """Test AgentResult structure."""
        usage = LLMUsage(
            prompt_tokens=100,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal("0.1"),
            completion_tokens=50,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal("0.1"),
            total_tokens=150,
            total_price=Decimal("0.2"),
            currency="USD",
            latency=0.5,
        )

        result = AgentResult(
            text="Final answer",
            files=[],
            usage=usage,
            finish_reason="stop",
        )

        assert result.text == "Final answer"
        assert result.files == []
        assert result.usage == usage
        assert result.finish_reason == "stop"


class TestOrganizeUserQuery:
    """Tests for _organize_user_query method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            runner.files = []
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.file_upload_config = None
            return runner

    def test_simple_query_without_files(self, mock_runner):
        """Test organizing a simple query without files."""
        result = mock_runner._organize_user_query("Hello world", [])

        assert len(result) == 1
        assert isinstance(result[0], UserPromptMessage)
        assert result[0].content == "Hello world"

    def test_query_with_files(self, mock_runner):
        """Test organizing a query with files."""
        from core.file.models import File

        mock_file = MagicMock(spec=File)
        mock_runner.files = [mock_file]

        with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
            from core.model_runtime.entities.message_entities import ImagePromptMessageContent

            mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
                data="http://example.com/image.jpg",
                format="url",
                mime_type="image/jpeg",
            )

            result = mock_runner._organize_user_query("Describe this image", [])

            assert len(result) == 1
            assert isinstance(result[0], UserPromptMessage)
            assert isinstance(result[0].content, list)
            assert len(result[0].content) == 2  # Image + Text
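
The system-message tests above pin down a small contract: the template becomes the leading system message only when one is not already present, and an empty template with no history yields an empty list. A minimal, self-contained sketch of that contract (class and function names here are illustrative stand-ins, not the runner's actual implementation):

    from dataclasses import dataclass


    @dataclass
    class _SystemMessage:
        content: str


    @dataclass
    class _UserMessage:
        content: str


    def init_system_message(template: str, messages: list) -> list:
        # Illustrative only: mirrors the behaviour asserted in the tests above.
        if not messages:
            # No history: a non-empty template becomes the sole system message.
            return [_SystemMessage(content=template)] if template else []
        if not isinstance(messages[0], _SystemMessage):
            # History without a leading system message: prepend the template.
            return [_SystemMessage(content=template), *messages]
        # A system message is already present: leave the history untouched.
        return messages


    assert init_system_message("", []) == []
    assert len(init_system_message("S", [_UserMessage("hi")])) == 2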
@@ -0,0 +1,191 @@
"""Tests for agent entities."""

from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext


class TestExecutionContext:
    """Tests for ExecutionContext entity."""

    def test_create_with_all_fields(self):
        """Test creating ExecutionContext with all fields."""
        context = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
            conversation_id="conv-789",
            message_id="msg-012",
            tenant_id="tenant-345",
        )

        assert context.user_id == "user-123"
        assert context.app_id == "app-456"
        assert context.conversation_id == "conv-789"
        assert context.message_id == "msg-012"
        assert context.tenant_id == "tenant-345"

    def test_create_minimal(self):
        """Test creating minimal ExecutionContext."""
        context = ExecutionContext.create_minimal(user_id="user-123")

        assert context.user_id == "user-123"
        assert context.app_id is None
        assert context.conversation_id is None
        assert context.message_id is None
        assert context.tenant_id is None

    def test_to_dict(self):
        """Test converting ExecutionContext to dictionary."""
        context = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
            conversation_id="conv-789",
            message_id="msg-012",
            tenant_id="tenant-345",
        )

        result = context.to_dict()

        assert result == {
            "user_id": "user-123",
            "app_id": "app-456",
            "conversation_id": "conv-789",
            "message_id": "msg-012",
            "tenant_id": "tenant-345",
        }

    def test_with_updates(self):
        """Test creating new context with updates."""
        original = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
        )

        updated = original.with_updates(message_id="msg-789")

        # Original should be unchanged
        assert original.message_id is None
        # Updated should have new value
        assert updated.message_id == "msg-789"
        assert updated.user_id == "user-123"
        assert updated.app_id == "app-456"


class TestAgentLog:
    """Tests for AgentLog entity."""

    def test_create_log_with_required_fields(self):
        """Test creating AgentLog with required fields."""
        log = AgentLog(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )

        assert log.label == "ROUND 1"
        assert log.log_type == AgentLog.LogType.ROUND
        assert log.status == AgentLog.LogStatus.START
        assert log.data == {"key": "value"}
        assert log.id is not None  # Auto-generated
        assert log.parent_id is None
        assert log.error is None

    def test_log_type_enum(self):
        """Test LogType enum values."""
        assert AgentLog.LogType.ROUND == "round"
        assert AgentLog.LogType.THOUGHT == "thought"
        assert AgentLog.LogType.TOOL_CALL == "tool_call"

    def test_log_status_enum(self):
        """Test LogStatus enum values."""
        assert AgentLog.LogStatus.START == "start"
        assert AgentLog.LogStatus.SUCCESS == "success"
        assert AgentLog.LogStatus.ERROR == "error"

    def test_log_metadata_enum(self):
        """Test LogMetadata enum values."""
        assert AgentLog.LogMetadata.STARTED_AT == "started_at"
        assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
        assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
        assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
        assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
        assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"


class TestAgentScratchpadUnit:
    """Tests for AgentScratchpadUnit entity."""

    def test_is_final_with_final_answer_action(self):
        """Test is_final returns True for Final Answer action."""
        unit = AgentScratchpadUnit(
            thought="I know the answer",
            action=AgentScratchpadUnit.Action(
                action_name="Final Answer",
                action_input="The answer is 42",
            ),
        )

        assert unit.is_final() is True

    def test_is_final_with_tool_action(self):
        """Test is_final returns False for tool action."""
        unit = AgentScratchpadUnit(
            thought="I need to search",
            action=AgentScratchpadUnit.Action(
                action_name="search",
                action_input={"query": "test"},
            ),
        )

        assert unit.is_final() is False

    def test_is_final_with_no_action(self):
        """Test is_final returns True when no action."""
        unit = AgentScratchpadUnit(
            thought="Just thinking",
        )

        assert unit.is_final() is True

    def test_action_to_dict(self):
        """Test Action.to_dict method."""
        action = AgentScratchpadUnit.Action(
            action_name="search",
            action_input={"query": "test"},
        )

        result = action.to_dict()

        assert result == {
            "action": "search",
            "action_input": {"query": "test"},
        }


class TestAgentEntity:
    """Tests for AgentEntity."""

    def test_strategy_enum(self):
        """Test Strategy enum values."""
        assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought"
        assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling"

    def test_create_with_prompt(self):
        """Test creating AgentEntity with prompt."""
        prompt = AgentPromptEntity(
            first_prompt="You are a helpful assistant.",
            next_iteration="Continue thinking...",
        )

        entity = AgentEntity(
            provider="openai",
            model="gpt-4",
            strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
            prompt=prompt,
            max_iteration=5,
        )

        assert entity.provider == "openai"
        assert entity.model == "gpt-4"
        assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
        assert entity.prompt == prompt
        assert entity.max_iteration == 5
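
The test_with_updates case asserts copy-on-write semantics: the original context is untouched and only the supplied fields change in the returned copy. A minimal sketch of that pattern using a frozen dataclass (illustrative; the real ExecutionContext may well be a pydantic model with different internals):

    from dataclasses import asdict, dataclass, replace


    @dataclass(frozen=True)
    class Ctx:
        user_id: str
        app_id: str | None = None
        message_id: str | None = None

        def to_dict(self) -> dict:
            return asdict(self)

        def with_updates(self, **changes) -> "Ctx":
            # replace() builds a new instance; the frozen original stays unchanged.
            return replace(self, **changes)


    original = Ctx(user_id="user-123", app_id="app-456")
    updated = original.with_updates(message_id="msg-789")
    assert original.message_id is None
    assert updated.message_id == "msg-789"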
@@ -0,0 +1,48 @@
from unittest.mock import MagicMock

from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.workflow.graph_events import NodeRunStreamChunkEvent
from core.workflow.nodes import NodeType


class DummyQueueManager:
    def __init__(self) -> None:
        self.published = []

    def publish(self, event, publish_from: PublishFrom) -> None:
        self.published.append((event, publish_from))


def test_skip_empty_final_chunk() -> None:
    queue_manager = DummyQueueManager()
    runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app")

    empty_final_event = NodeRunStreamChunkEvent(
        id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        selector=["node", "text"],
        chunk="",
        is_final=True,
    )

    runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event)
    assert queue_manager.published == []

    normal_event = NodeRunStreamChunkEvent(
        id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        selector=["node", "text"],
        chunk="hi",
        is_final=False,
    )

    runner._handle_event(workflow_entry=MagicMock(), event=normal_event)

    assert len(queue_manager.published) == 1
    published_event, publish_from = queue_manager.published[0]
    assert publish_from == PublishFrom.APPLICATION_MANAGER
    assert published_event.text == "hi"
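
test_skip_empty_final_chunk encodes a small guard: a final chunk with an empty payload is dropped rather than forwarded to the queue, while non-empty chunks pass through. A self-contained sketch of that guard (a hypothetical handler; the real _handle_event does considerably more):

    from dataclasses import dataclass, field


    @dataclass
    class Chunk:
        text: str
        is_final: bool


    @dataclass
    class Queue:
        published: list = field(default_factory=list)

        def publish(self, chunk: Chunk) -> None:
            self.published.append(chunk)


    def handle_chunk(queue: Queue, chunk: Chunk) -> None:
        # Drop empty final chunks so consumers never receive a blank trailing event.
        if chunk.is_final and not chunk.text:
            return
        queue.publish(chunk)


    q = Queue()
    handle_chunk(q, Chunk(text="", is_final=True))    # skipped
    handle_chunk(q, Chunk(text="hi", is_final=False))  # forwarded
    assert [c.text for c in q.published] == ["hi"]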
@@ -0,0 +1,231 @@
"""Tests for ResponseStreamCoordinator object field streaming."""

from unittest.mock import MagicMock

from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.graph_events import (
    ChunkType,
    NodeRunStreamChunkEvent,
    ToolCall,
    ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.template import Template, VariableSegment
from core.workflow.runtime import VariablePool


class TestResponseCoordinatorObjectStreaming:
    """Test streaming of object-type variables with child fields."""

    def test_object_field_streaming(self):
        """Test that when selecting an object variable, all child field streams are forwarded."""
        # Create mock graph and variable pool
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)

        # Mock nodes
        llm_node = MagicMock()
        llm_node.id = "llm_node"
        llm_node.node_type = NodeType.LLM
        llm_node.execution_type = MagicMock()
        llm_node.blocks_variable_output = MagicMock(return_value=False)

        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        response_node.execution_type = MagicMock()
        response_node.blocks_variable_output = MagicMock(return_value=False)

        # Mock template for response node
        response_node.node_data = MagicMock(spec=BaseNodeData)
        response_node.node_data.answer = "{{#llm_node.generation#}}"

        graph.nodes = {
            "llm_node": llm_node,
            "response_node": response_node,
        }
        graph.root_node = llm_node
        graph.get_outgoing_edges = MagicMock(return_value=[])

        # Create coordinator
        coordinator = ResponseStreamCoordinator(variable_pool, graph)

        # Track execution
        coordinator.track_node_execution("llm_node", "exec_123")
        coordinator.track_node_execution("response_node", "exec_456")

        # Simulate streaming events for child fields of generation object
        # 1. Content stream
        content_event_1 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk="Hello",
            is_final=False,
            chunk_type=ChunkType.TEXT,
        )
        content_event_2 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk=" world",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )

        # 2. Tool call stream
        tool_call_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_calls"],
            chunk='{"query": "test"}',
            is_final=True,
            chunk_type=ChunkType.TOOL_CALL,
            tool_call=ToolCall(
                id="call_123",
                name="search",
                arguments='{"query": "test"}',
            ),
        )

        # 3. Tool result stream
        tool_result_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_results"],
            chunk="Found 10 results",
            is_final=True,
            chunk_type=ChunkType.TOOL_RESULT,
            tool_result=ToolResult(
                id="call_123",
                name="search",
                output="Found 10 results",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
        )

        # Intercept these events
        coordinator.intercept_event(content_event_1)
        coordinator.intercept_event(tool_call_event)
        coordinator.intercept_event(tool_result_event)
        coordinator.intercept_event(content_event_2)

        # Verify that all child streams are buffered
        assert ("llm_node", "generation", "content") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers

        # Verify payloads are preserved in buffered events
        buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
        assert buffered_call.tool_call is not None
        assert buffered_call.tool_call.id == "call_123"
        buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
        assert buffered_result.tool_result is not None
        assert buffered_result.tool_result.status == "success"

        # Verify we can find child streams
        child_streams = coordinator._find_child_streams(["llm_node", "generation"])
        assert len(child_streams) == 3
        assert ("llm_node", "generation", "content") in child_streams
        assert ("llm_node", "generation", "tool_calls") in child_streams
        assert ("llm_node", "generation", "tool_results") in child_streams

    def test_find_child_streams(self):
        """Test the _find_child_streams method."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)

        coordinator = ResponseStreamCoordinator(variable_pool, graph)

        # Add some mock streams
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
            ("node1", "generation", "tool_calls"): [],
            ("node1", "generation", "thought"): [],
            ("node1", "text"): [],  # Not a child of generation
            ("node2", "generation", "content"): [],  # Different node
        }

        # Find children of node1.generation
        children = coordinator._find_child_streams(["node1", "generation"])

        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children
        assert ("node1", "text") not in children
        assert ("node2", "generation", "content") not in children

    def test_find_child_streams_with_closed_streams(self):
        """Test that _find_child_streams also considers closed streams."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)

        coordinator = ResponseStreamCoordinator(variable_pool, graph)

        # Add some streams - some buffered, some closed
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
        }
        coordinator._closed_streams = {
            ("node1", "generation", "tool_calls"),
            ("node1", "generation", "thought"),
        }

        # Should find all children regardless of whether they're in buffers or closed
        children = coordinator._find_child_streams(["node1", "generation"])

        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children

    def test_special_selector_rewrites_to_active_response_node(self):
        """Ensure special selectors attribute streams to the active response node."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)

        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        graph.nodes = {"response_node": response_node}
        graph.root_node = response_node

        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        coordinator.track_node_execution("response_node", "exec_resp")

        coordinator._active_session = ResponseSession(
            node_id="response_node",
            template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
        )

        event = NodeRunStreamChunkEvent(
            id="stream_1",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["sys", "foo"],
            chunk="hi",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )

        coordinator._stream_buffers[("sys", "foo")] = [event]
        coordinator._stream_positions[("sys", "foo")] = 0
        coordinator._closed_streams.add(("sys", "foo"))

        events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))

        assert is_complete
        assert len(events) == 1
        rewritten = events[0]
        assert rewritten.node_id == "response_node"
        assert rewritten.id == "exec_resp"
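
Both _find_child_streams tests reduce to prefix matching over the union of buffered and closed stream selectors. A self-contained sketch of that lookup (an assumption about the shape of the behaviour, not the coordinator's actual code):

    def find_child_streams(
        prefix: list[str],
        buffers: dict[tuple, list],
        closed: set[tuple],
    ) -> set[tuple]:
        # A stream is a child of `prefix` when its selector starts with the prefix
        # and has at least one extra segment, whether it is still buffered or closed.
        candidates = set(buffers) | closed
        plen = len(prefix)
        return {sel for sel in candidates if len(sel) > plen and list(sel[:plen]) == prefix}


    buffers = {("node1", "generation", "content"): []}
    closed = {("node1", "generation", "tool_calls"), ("node1", "text")}
    children = find_child_streams(["node1", "generation"], buffers, closed)
    assert children == {("node1", "generation", "content"), ("node1", "generation", "tool_calls")}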
@@ -0,0 +1,328 @@
"""Tests for StreamChunkEvent and its subclasses."""

from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
from core.workflow.node_events import (
    ChunkType,
    StreamChunkEvent,
    ThoughtChunkEvent,
    ToolCallChunkEvent,
    ToolResultChunkEvent,
)


class TestChunkType:
    """Tests for ChunkType enum."""

    def test_chunk_type_values(self):
        """Test that ChunkType has expected values."""
        assert ChunkType.TEXT == "text"
        assert ChunkType.TOOL_CALL == "tool_call"
        assert ChunkType.TOOL_RESULT == "tool_result"
        assert ChunkType.THOUGHT == "thought"

    def test_chunk_type_is_str_enum(self):
        """Test that ChunkType values are strings."""
        for chunk_type in ChunkType:
            assert isinstance(chunk_type.value, str)


class TestStreamChunkEvent:
    """Tests for base StreamChunkEvent."""

    def test_create_with_required_fields(self):
        """Test creating StreamChunkEvent with required fields."""
        event = StreamChunkEvent(
            selector=["node1", "text"],
            chunk="Hello",
        )

        assert event.selector == ["node1", "text"]
        assert event.chunk == "Hello"
        assert event.is_final is False
        assert event.chunk_type == ChunkType.TEXT

    def test_create_with_all_fields(self):
        """Test creating StreamChunkEvent with all fields."""
        event = StreamChunkEvent(
            selector=["node1", "output"],
            chunk="World",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )

        assert event.selector == ["node1", "output"]
        assert event.chunk == "World"
        assert event.is_final is True
        assert event.chunk_type == ChunkType.TEXT

    def test_default_chunk_type_is_text(self):
        """Test that default chunk_type is TEXT."""
        event = StreamChunkEvent(
            selector=["node1", "text"],
            chunk="test",
        )

        assert event.chunk_type == ChunkType.TEXT

    def test_serialization(self):
        """Test that event can be serialized to dict."""
        event = StreamChunkEvent(
            selector=["node1", "text"],
            chunk="Hello",
            is_final=True,
        )

        data = event.model_dump()

        assert data["selector"] == ["node1", "text"]
        assert data["chunk"] == "Hello"
        assert data["is_final"] is True
        assert data["chunk_type"] == "text"


class TestToolCallChunkEvent:
    """Tests for ToolCallChunkEvent."""

    def test_create_with_required_fields(self):
        """Test creating ToolCallChunkEvent with required fields."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=ToolCall(id="call_123", name="weather", arguments=None),
        )

        assert event.selector == ["node1", "tool_calls"]
        assert event.chunk == '{"city": "Beijing"}'
        assert event.tool_call.id == "call_123"
        assert event.tool_call.name == "weather"
        assert event.chunk_type == ChunkType.TOOL_CALL

    def test_chunk_type_is_tool_call(self):
        """Test that chunk_type is always TOOL_CALL."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
        )

        assert event.chunk_type == ChunkType.TOOL_CALL

    def test_tool_arguments_field(self):
        """Test tool_arguments field."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"param": "value"}',
            tool_call=ToolCall(
                id="call_123",
                name="test_tool",
                arguments='{"param": "value"}',
            ),
        )

        assert event.tool_call.arguments == '{"param": "value"}'

    def test_serialization(self):
        """Test that event can be serialized to dict."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=ToolCall(
                id="call_123",
                name="weather",
                arguments='{"city": "Beijing"}',
            ),
            is_final=True,
        )

        data = event.model_dump()

        assert data["chunk_type"] == "tool_call"
        assert data["tool_call"]["id"] == "call_123"
        assert data["tool_call"]["name"] == "weather"
        assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
        assert data["is_final"] is True


class TestToolResultChunkEvent:
    """Tests for ToolResultChunkEvent."""

    def test_create_with_required_fields(self):
        """Test creating ToolResultChunkEvent with required fields."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny, 25°C",
            tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
        )

        assert event.selector == ["node1", "tool_results"]
        assert event.chunk == "Weather: Sunny, 25°C"
        assert event.tool_result.id == "call_123"
        assert event.tool_result.name == "weather"
        assert event.chunk_type == ChunkType.TOOL_RESULT

    def test_chunk_type_is_tool_result(self):
        """Test that chunk_type is always TOOL_RESULT."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )

        assert event.chunk_type == ChunkType.TOOL_RESULT

    def test_tool_files_default_empty(self):
        """Test that tool_files defaults to empty list."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )

        assert event.tool_result.files == []

    def test_tool_files_with_values(self):
        """Test tool_files with file IDs."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(
                id="call_123",
                name="test_tool",
                files=["file_1", "file_2"],
            ),
        )

        assert event.tool_result.files == ["file_1", "file_2"]

    def test_tool_error_output(self):
        """Test error output captured in tool_result."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="",
            tool_result=ToolResult(
                id="call_123",
                name="test_tool",
                output="Tool execution failed",
                status=ToolResultStatus.ERROR,
            ),
        )

        assert event.tool_result.output == "Tool execution failed"
        assert event.tool_result.status == ToolResultStatus.ERROR

    def test_serialization(self):
        """Test that event can be serialized to dict."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny",
            tool_result=ToolResult(
                id="call_123",
                name="weather",
                output="Weather: Sunny",
                files=["file_1"],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        )

        data = event.model_dump()

        assert data["chunk_type"] == "tool_result"
        assert data["tool_result"]["id"] == "call_123"
        assert data["tool_result"]["name"] == "weather"
        assert data["tool_result"]["files"] == ["file_1"]
        assert data["is_final"] is True


class TestThoughtChunkEvent:
    """Tests for ThoughtChunkEvent."""

    def test_create_with_required_fields(self):
        """Test creating ThoughtChunkEvent with required fields."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to query the weather...",
        )

        assert event.selector == ["node1", "thought"]
        assert event.chunk == "I need to query the weather..."
        assert event.chunk_type == ChunkType.THOUGHT

    def test_chunk_type_is_thought(self):
        """Test that chunk_type is always THOUGHT."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="thinking...",
        )

        assert event.chunk_type == ChunkType.THOUGHT

    def test_serialization(self):
        """Test that event can be serialized to dict."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to analyze this...",
            is_final=False,
        )

        data = event.model_dump()

        assert data["chunk_type"] == "thought"
        assert data["chunk"] == "I need to analyze this..."
        assert data["is_final"] is False


class TestEventInheritance:
    """Tests for event inheritance relationships."""

    def test_tool_call_is_stream_chunk(self):
        """Test that ToolCallChunkEvent is a StreamChunkEvent."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test", arguments=None),
        )

        assert isinstance(event, StreamChunkEvent)

    def test_tool_result_is_stream_chunk(self):
        """Test that ToolResultChunkEvent is a StreamChunkEvent."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test"),
        )

        assert isinstance(event, StreamChunkEvent)

    def test_thought_is_stream_chunk(self):
        """Test that ThoughtChunkEvent is a StreamChunkEvent."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="thinking...",
        )

        assert isinstance(event, StreamChunkEvent)

    def test_all_events_have_common_fields(self):
        """Test that all events have common StreamChunkEvent fields."""
        events = [
            StreamChunkEvent(selector=["n", "t"], chunk="a"),
            ToolCallChunkEvent(
                selector=["n", "t"],
                chunk="b",
                tool_call=ToolCall(id="1", name="t", arguments=None),
            ),
            ToolResultChunkEvent(
                selector=["n", "t"],
                chunk="c",
                tool_result=ToolResult(id="1", name="t"),
            ),
            ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
        ]

        for event in events:
            assert hasattr(event, "selector")
            assert hasattr(event, "chunk")
            assert hasattr(event, "is_final")
            assert hasattr(event, "chunk_type")
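
The chunk-event tests above rest on two properties: every subclass is still a StreamChunkEvent, and chunk_type is fixed per subclass rather than supplied by the caller. A minimal sketch of that shape, assuming a pydantic v2 style model (field names follow the tests; the defaults and class names here are illustrative, not the project's definitions):

    from enum import Enum

    from pydantic import BaseModel


    class ChunkKind(str, Enum):
        TEXT = "text"
        THOUGHT = "thought"


    class BaseChunk(BaseModel):
        selector: list[str]
        chunk: str
        is_final: bool = False
        chunk_type: ChunkKind = ChunkKind.TEXT


    class ThoughtChunk(BaseChunk):
        # The subclass pins the discriminator so callers never pass it explicitly.
        chunk_type: ChunkKind = ChunkKind.THOUGHT


    event = ThoughtChunk(selector=["node1", "thought"], chunk="thinking...")
    assert isinstance(event, BaseChunk)
    assert event.model_dump()["chunk_type"] == "thought"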
@@ -0,0 +1,149 @@
import types
from collections.abc import Generator
from typing import Any

import pytest

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import ToolCallResult
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
from core.workflow.nodes.llm.node import LLMNode


class _StubModelInstance:
    """Minimal stub to satisfy _stream_llm_events signature."""

    provider_model_bundle = None


def _drain(generator: Generator[NodeEventBase, None, Any]):
    events: list = []
    try:
        while True:
            events.append(next(generator))
    except StopIteration as exc:
        return events, exc.value


@pytest.fixture(autouse=True)
def patch_deduct_llm_quota(monkeypatch):
    # Avoid touching real quota logic during unit tests
    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)


def _make_llm_node(reasoning_format: str) -> LLMNode:
    node = LLMNode.__new__(LLMNode)
    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
    object.__setattr__(node, "tenant_id", "tenant")
    return node


def test_stream_llm_events_extracts_reasoning_for_tagged():
    node = _make_llm_node(reasoning_format="tagged")
    tagged_text = "<think>Thought</think>Answer"
    usage = LLMUsage.empty_usage()

    def generator():
        yield ModelInvokeCompletedEvent(
            text=tagged_text,
            usage=usage,
            finish_reason="stop",
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )

    assert events == []
    clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
    assert clean_text == tagged_text  # original preserved for output
    assert reasoning_content == ""  # tagged mode keeps reasoning separate
    assert gen_clean == "Answer"  # stripped content for generation
    assert gen_reasoning == "Thought"  # reasoning extracted from <think> tag
    assert ret_usage == usage
    assert finish_reason == "stop"
    assert structured is None
    assert gen_data is None

    # generation building should include reasoning and sequence
    generation_content = gen_clean or clean_text
    sequence = [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len(generation_content)},
    ]
    assert sequence == [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len("Answer")},
    ]


def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
    node = _make_llm_node(reasoning_format="tagged")
    plain_text = "Hello world"
    usage = LLMUsage.empty_usage()

    def generator():
        yield ModelInvokeCompletedEvent(
            text=plain_text,
            usage=usage,
            finish_reason=None,
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )

    assert events == []
    _, _, gen_reasoning, gen_clean, *_ = returned
    generation_content = gen_clean or plain_text
    assert gen_reasoning == ""
    assert generation_content == plain_text
    # Empty reasoning should imply empty sequence in generation construction
    sequence = []
    assert sequence == []


def test_serialize_tool_call_strips_files_to_ids():
    file_cls = pytest.importorskip("core.file").File
    file_type = pytest.importorskip("core.file.enums").FileType
    transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod

    file_with_id = file_cls(
        id="f1",
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        remote_url="http://example.com/f1",
        storage_key="k1",
    )
    file_with_related = file_cls(
        id=None,
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        related_id="rel2",
        remote_url="http://example.com/f2",
        storage_key="k2",
    )
    tool_call = ToolCallResult(
        id="tc",
        name="do",
        arguments='{"a":1}',
        output="ok",
        files=[file_with_id, file_with_related],
        status=ToolResultStatus.SUCCESS,
    )

    serialized = LLMNode._serialize_tool_call(tool_call)

    assert serialized["files"] == ["f1", "rel2"]
    assert serialized["id"] == "tc"
    assert serialized["name"] == "do"
    assert serialized["arguments"] == '{"a":1}'
    assert serialized["output"] == "ok"
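
The tagged-reasoning tests assume that reasoning wrapped in <think>…</think> is split away from the visible answer while the raw text is preserved separately. A small illustrative helper showing one way to do that split (not the LLM node's actual parsing code):

    import re

    _THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


    def split_tagged_reasoning(text: str) -> tuple[str, str]:
        """Return (reasoning, clean_text) for a completion that may carry a <think> block."""
        match = _THINK_RE.search(text)
        if not match:
            return "", text
        reasoning = match.group(1)
        clean = (text[: match.start()] + text[match.end():]).strip()
        return reasoning, clean


    assert split_tagged_reasoning("<think>Thought</think>Answer") == ("Thought", "Answer")
    assert split_tagged_reasoning("Hello world") == ("", "Hello world")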