From 2b23c43434b535dc1aba2f50e1da57d661c26991 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 9 Dec 2025 11:26:02 +0800 Subject: [PATCH] feat: add agent package --- api/controllers/console/app/message.py | 1 + api/controllers/console/app/workflow_run.py | 1 - api/controllers/service_api/app/message.py | 1 + api/controllers/web/message.py | 1 + api/core/agent/agent_app_runner.py | 358 ++++++++++++++ api/core/agent/base_agent_runner.py | 13 +- api/core/agent/cot_agent_runner.py | 431 ---------------- api/core/agent/cot_chat_agent_runner.py | 118 ----- api/core/agent/cot_completion_agent_runner.py | 87 ---- api/core/agent/entities.py | 93 ++++ api/core/agent/fc_agent_runner.py | 465 ------------------ api/core/agent/patterns/README.md | 67 +++ api/core/agent/patterns/__init__.py | 19 + api/core/agent/patterns/base.py | 444 +++++++++++++++++ api/core/agent/patterns/function_call.py | 273 ++++++++++ api/core/agent/patterns/react.py | 402 +++++++++++++++ api/core/agent/patterns/strategy_factory.py | 107 ++++ .../advanced_chat/generate_task_pipeline.py | 177 ++++++- api/core/app/apps/agent_chat/app_runner.py | 24 +- .../common/workflow_response_converter.py | 2 +- .../apps/workflow/generate_task_pipeline.py | 38 +- api/core/app/apps/workflow_app_runner.py | 9 + .../app/entities/llm_generation_entities.py | 69 +++ api/core/app/entities/queue_entities.py | 31 ++ api/core/app/entities/task_entities.py | 55 ++- .../easy_ui_based_generate_task_pipeline.py | 83 +++- .../task_pipeline/message_cycle_manager.py | 27 +- ...hemy_workflow_node_execution_repository.py | 89 ++++ api/core/tools/__base/tool.py | 55 +++ api/core/workflow/enums.py | 1 + .../response_coordinator/coordinator.py | 167 +++++-- api/core/workflow/graph_events/__init__.py | 2 + api/core/workflow/graph_events/node.py | 27 +- api/core/workflow/node_events/__init__.py | 8 + api/core/workflow/node_events/node.py | 40 +- api/core/workflow/nodes/base/node.py | 56 +++ api/core/workflow/nodes/llm/__init__.py | 2 + api/core/workflow/nodes/llm/entities.py | 28 ++ api/core/workflow/nodes/llm/llm_utils.py | 92 ++++ api/core/workflow/nodes/llm/node.py | 465 +++++++++++++++++- api/fields/conversation_fields.py | 1 + api/fields/message_fields.py | 1 + api/fields/workflow_run_fields.py | 1 + api/models/__init__.py | 2 + api/models/model.py | 97 ++++ api/services/llm_generation_service.py | 131 +++++ api/services/workflow_run_service.py | 22 +- .../core/agent/patterns/__init__.py | 0 .../core/agent/patterns/test_base.py | 324 ++++++++++++ .../core/agent/patterns/test_function_call.py | 332 +++++++++++++ .../core/agent/patterns/test_react.py | 224 +++++++++ .../agent/patterns/test_strategy_factory.py | 203 ++++++++ .../core/agent/test_agent_app_runner.py | 388 +++++++++++++++ .../unit_tests/core/agent/test_entities.py | 191 +++++++ .../graph_engine/test_response_coordinator.py | 169 +++++++ .../node_events/test_stream_chunk_events.py | 336 +++++++++++++ web/app/components/workflow/constants.ts | 4 + .../nodes/agent/components/tool-icon.tsx | 6 +- .../nodes/llm/components/tools-config.tsx | 58 +++ .../workflow/nodes/llm/constants.ts | 41 ++ .../components/workflow/nodes/llm/default.ts | 1 + .../components/workflow/nodes/llm/node.tsx | 41 +- .../components/workflow/nodes/llm/panel.tsx | 28 ++ .../components/workflow/nodes/llm/types.ts | 2 + .../workflow/nodes/llm/use-config.ts | 100 +++- .../run/agent-log/agent-log-trigger.tsx | 11 +- web/app/components/workflow/run/node.tsx | 3 +- .../components/workflow/run/result-panel.tsx | 3 +- 
.../run/utils/format-log/agent/index.ts | 2 +- web/i18n/en-US/workflow.ts | 4 + web/i18n/zh-Hans/workflow.ts | 4 + 71 files changed, 5945 insertions(+), 1213 deletions(-) create mode 100644 api/core/agent/agent_app_runner.py delete mode 100644 api/core/agent/cot_agent_runner.py delete mode 100644 api/core/agent/cot_chat_agent_runner.py delete mode 100644 api/core/agent/cot_completion_agent_runner.py delete mode 100644 api/core/agent/fc_agent_runner.py create mode 100644 api/core/agent/patterns/README.md create mode 100644 api/core/agent/patterns/__init__.py create mode 100644 api/core/agent/patterns/base.py create mode 100644 api/core/agent/patterns/function_call.py create mode 100644 api/core/agent/patterns/react.py create mode 100644 api/core/agent/patterns/strategy_factory.py create mode 100644 api/core/app/entities/llm_generation_entities.py create mode 100644 api/services/llm_generation_service.py create mode 100644 api/tests/unit_tests/core/agent/patterns/__init__.py create mode 100644 api/tests/unit_tests/core/agent/patterns/test_base.py create mode 100644 api/tests/unit_tests/core/agent/patterns/test_function_call.py create mode 100644 api/tests/unit_tests/core/agent/patterns/test_react.py create mode 100644 api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py create mode 100644 api/tests/unit_tests/core/agent/test_agent_app_runner.py create mode 100644 api/tests/unit_tests/core/agent/test_entities.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py create mode 100644 api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py create mode 100644 web/app/components/workflow/nodes/llm/components/tools-config.tsx create mode 100644 web/app/components/workflow/nodes/llm/constants.ts diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 377297c84c..6b5b0d9eb3 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -201,6 +201,7 @@ message_detail_model = console_ns.model( "status": fields.String, "error": fields.String, "parent_message_id": fields.String, + "generation_detail": fields.Raw, }, ) diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 8f1871f1e9..8360785d19 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -359,7 +359,6 @@ class WorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_list_model) def get(self, app_model: App, run_id): """ Get workflow run node execution list diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b8e5ed28e4..e134253547 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -86,6 +86,7 @@ def build_message_model(api_or_ns: Api | Namespace): "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "status": fields.String, "error": fields.String, + "generation_detail": fields.Raw, } return api_or_ns.model("Message", message_fields) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9f9aa4838c..afa935afa6 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -55,6 +55,7 @@ class MessageListApi(WebApiResource): "metadata": 
fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, + "generation_detail": fields.Raw, } message_infinite_scroll_pagination_fields = { diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py new file mode 100644 index 0000000000..9be5be5c7c --- /dev/null +++ b/api/core/agent/agent_app_runner.py @@ -0,0 +1,358 @@ +import logging +from collections.abc import Generator +from copy import deepcopy +from typing import Any + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.entities import AgentEntity, AgentLog, AgentResult +from core.agent.patterns.strategy_factory import StrategyFactory +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.file import file_manager +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMUsage, + PromptMessage, + PromptMessageContentType, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine +from models.model import Message + +logger = logging.getLogger(__name__) + + +class AgentAppRunner(BaseAgentRunner): + def _create_tool_invoke_hook(self, message: Message): + """ + Create a tool invoke hook that uses ToolEngine.agent_invoke. + This hook handles file creation and returns proper meta information. 
+ """ + # Get trace manager from app generate entity + trace_manager = self.application_generate_entity.trace_manager + + def tool_invoke_hook( + tool: Tool, tool_args: dict[str, Any], tool_name: str + ) -> tuple[str, list[str], ToolInvokeMeta]: + """Hook that uses agent_invoke for proper file and meta handling.""" + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters=tool_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback, + trace_manager=trace_manager, + app_id=self.application_generate_entity.app_config.app_id, + message_id=message.id, + conversation_id=self.conversation.id, + ) + + # Publish files and track IDs + for message_file_id in message_files: + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), + PublishFrom.APPLICATION_MANAGER, + ) + self._current_message_file_ids.append(message_file_id) + + return tool_invoke_response, message_files, tool_invoke_meta + + return tool_invoke_hook + + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: + """ + Run Agent application + """ + self.query = query + app_generate_entity = self.application_generate_entity + + app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" + + # convert tools into ModelRuntime Tool format + tool_instances, _ = self._init_prompt_tools() + + assert app_config.agent + + # Create tool invoke hook for agent_invoke + tool_invoke_hook = self._create_tool_invoke_hook(message) + + # Get instruction for ReAct strategy + instruction = self.app_config.prompt_template.simple_prompt_template or "" + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=self.model_features, + model_instance=self.model_instance, + tools=list(tool_instances.values()), + files=list(self.files), + max_iterations=app_config.agent.max_iteration, + context=self.build_execution_context(), + agent_strategy=self.config.strategy, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Initialize state variables + current_agent_thought_id = None + has_published_thought = False + current_tool_name: str | None = None + self._current_message_file_ids = [] + + # organize prompt messages + prompt_messages = self._organize_prompt_messages() + + # Run strategy + generator = strategy.run( + prompt_messages=prompt_messages, + model_parameters=app_generate_entity.model_conf.parameters, + stop=app_generate_entity.model_conf.stop, + stream=True, + ) + + # Consume generator and collect result + result: AgentResult | None = None + try: + while True: + try: + output = next(generator) + except StopIteration as e: + # Generator finished, get the return value + result = e.value + break + + if isinstance(output, LLMResultChunk): + # Handle LLM chunk + if current_agent_thought_id and not has_published_thought: + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + has_published_thought = True + + yield output + + elif isinstance(output, AgentLog): + # Handle Agent Log using log_type for type-safe dispatch + if output.status == AgentLog.LogStatus.START: + if output.log_type == AgentLog.LogType.ROUND: + # Start of a new round + message_file_ids: list[str] 
= [] + current_agent_thought_id = self.create_agent_thought( + message_id=message.id, + message="", + tool_name="", + tool_input="", + messages_ids=message_file_ids, + ) + has_published_thought = False + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call start - extract data from structured fields + current_tool_name = output.data.get("tool_name", "") + tool_input = output.data.get("tool_args", {}) + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=current_tool_name, + tool_input=tool_input, + thought=None, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.status == AgentLog.LogStatus.SUCCESS: + if output.log_type == AgentLog.LogType.THOUGHT: + pass + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call finished + tool_output = output.data.get("output") + # Get meta from strategy output (now properly populated) + tool_meta = output.data.get("meta") + + # Wrap tool_meta with tool_name as key (required by agent_service) + if tool_meta and current_tool_name: + tool_meta = {current_tool_name: tool_meta} + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=None, + observation=tool_output, + tool_invoke_meta=tool_meta, + answer=None, + messages_ids=self._current_message_file_ids, + ) + # Clear message file ids after saving + self._current_message_file_ids = [] + current_tool_name = None + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.ROUND: + if current_agent_thought_id is None: + continue + + # Round finished - save LLM usage and answer + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) + llm_result = output.data.get("llm_result") + final_answer = output.data.get("final_answer") + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=llm_result, + observation=None, + tool_invoke_meta=None, + answer=final_answer, + messages_ids=[], + llm_usage=llm_usage, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + except Exception: + # Re-raise any other exceptions + raise + + # Process final result + if isinstance(result, AgentResult): + final_answer = result.text + usage = result.usage or LLMUsage.empty_usage() + + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=self.model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) + + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Initialize system message + """ + if not prompt_messages and prompt_template: + return [ + SystemPromptMessage(content=prompt_template), + ] + + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + + return 
prompt_messages or [] + + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + As for now, gpt supports both fc and vision at the first iteration. + We need to remove the image messages from the prompt messages at the first iteration. + """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) + + return prompt_messages + + def _organize_prompt_messages(self): + # For ReAct strategy, use the agent prompt template + if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt: + prompt_template = self.config.prompt.first_prompt + else: + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" + + self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) + query_prompt_messages = self._organize_user_query(self.query or "", []) + + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=[*query_prompt_messages, *self._current_thoughts], + history_messages=self.history_prompt_messages, + memory=self.memory, + ).get_prompt() + + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] + if len(self._current_thoughts) != 0: + # clear messages after the first iteration + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + return prompt_messages diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..b59a9a3859 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,7 +5,7 @@ from typing import Union, cast from sqlalchemy import select -from core.agent.entities import AgentEntity, AgentToolEntity +from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if 
model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] + self.model_features = features self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] + def build_execution_context(self) -> ExecutionContext: + """Build execution context.""" + return ExecutionContext( + user_id=self.user_id, + app_id=self.app_config.app_id, + conversation_id=self.conversation.id, + message_id=self.message.id, + tenant_id=self.tenant_id, + ) + def _repack_app_generate_entity( self, app_generate_entity: AgentChatAppGenerateEntity ) -> AgentChatAppGenerateEntity: diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py deleted file mode 100644 index b32e35d0ca..0000000000 --- a/api/core/agent/cot_agent_runner.py +++ /dev/null @@ -1,431 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping, Sequence -from typing import Any - -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentScratchpadUnit -from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from core.ops.ops_trace_manager import TraceQueueManager -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from models.model import Message - -logger = logging.getLogger(__name__) - - -class CotAgentRunner(BaseAgentRunner, ABC): - _is_first_iteration = True - _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] - _agent_scratchpad: list[AgentScratchpadUnit] - _instruction: str - _query: str - _prompt_messages_tools: Sequence[PromptMessageTool] - - def run( - self, - message: Message, - query: str, - inputs: Mapping[str, str], - ) -> Generator: - """ - Run Cot agent application - """ - - app_generate_entity = self.application_generate_entity - self._repack_app_generate_entity(app_generate_entity) - self._init_react_state(query) - - trace_manager = app_generate_entity.trace_manager - - # check model mode - if "Observation" not in app_generate_entity.model_conf.stop: - if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append("Observation") - - app_config = self.app_config - assert app_config.agent - - # init instruction - inputs = inputs or {} - instruction = app_config.prompt_template.simple_prompt_template or "" - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - self._prompt_messages_tools = prompt_messages_tools - - function_call_state = True - llm_usage: dict[str, 
LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - agent_thought_id = "" # Initialize agent_thought_id - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - # continue to run until there is not any tool call - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - self._prompt_messages_tools = [] - - message_file_ids: list[str] = [] - - agent_thought_id = self.create_agent_thought( - message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids - ) - - if iteration_step > 1: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=[], - stop=app_generate_entity.model_conf.stop, - stream=True, - user=self.user_id, - callbacks=[], - ) - - usage_dict: dict[str, LLMUsage | None] = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) - scratchpad = AgentScratchpadUnit( - agent_response="", - thought="", - action_str="", - observation="", - action=None, - ) - - # publish agent thought if it's first iteration - if iteration_step == 1: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - for chunk in react_chunks: - if isinstance(chunk, AgentScratchpadUnit.Action): - action = chunk - # detect action - assert scratchpad.agent_response is not None - scratchpad.agent_response += json.dumps(chunk.model_dump()) - scratchpad.action_str = json.dumps(chunk.model_dump()) - scratchpad.action = action - else: - assert scratchpad.agent_response is not None - scratchpad.agent_response += chunk - assert scratchpad.thought is not None - scratchpad.thought += chunk - yield LLMResultChunk( - model=self.model_config.model, - prompt_messages=prompt_messages, - system_fingerprint="", - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), - ) - - assert scratchpad.thought is not None - scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) - - # get llm usage - if "usage" in usage_dict: - if usage_dict["usage"] is not None: - increase_usage(llm_usage, usage_dict["usage"]) - else: - usage_dict["usage"] = LLMUsage.empty_usage() - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), - tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action 
else {}, - tool_invoke_meta={}, - thought=scratchpad.thought or "", - observation="", - answer=scratchpad.agent_response or "", - messages_ids=[], - llm_usage=usage_dict["usage"], - ) - - if not scratchpad.is_final(): - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - if not scratchpad.action: - # failed to extract action, return final answer directly - final_answer = "" - else: - if scratchpad.action.action_name.lower() == "final answer": - # action is final answer, return final answer directly - try: - if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False) - elif isinstance(scratchpad.action.action_input, str): - final_answer = scratchpad.action.action_input - else: - final_answer = f"{scratchpad.action.action_input}" - except TypeError: - final_answer = f"{scratchpad.action.action_input}" - else: - function_call_state = True - # action is tool call, invoke tool - tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( - action=scratchpad.action, - tool_instances=tool_instances, - message_file_ids=message_file_ids, - trace_manager=trace_manager, - ) - scratchpad.observation = tool_invoke_response - scratchpad.agent_response = tool_invoke_response - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=scratchpad.action.action_name, - tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought or "", - observation={scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, - answer=scratchpad.agent_response, - messages_ids=message_file_ids, - llm_usage=usage_dict["usage"], - ) - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool message - for prompt_tool in self._prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] - ), - system_fingerprint="", - ) - - # save agent thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name="", - tool_input={}, - tool_invoke_meta={}, - thought=final_answer, - observation={}, - answer=final_answer, - messages_ids=[], - ) - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - def _handle_invoke_action( - self, - action: AgentScratchpadUnit.Action, - tool_instances: Mapping[str, Tool], - message_file_ids: list[str], - trace_manager: TraceQueueManager | None = None, - ) -> tuple[str, ToolInvokeMeta]: - """ - handle invoke action - :param action: action - :param tool_instances: tool instances - :param message_file_ids: message file ids - :param trace_manager: trace manager - :return: observation, meta - """ - # action is tool call, invoke tool - tool_call_name = action.action_name - tool_call_args = action.action_input - tool_instance = 
tool_instances.get(tool_call_name) - - if not tool_instance: - answer = f"there is not a tool named {tool_call_name}" - return answer, ToolInvokeMeta.error_instance(answer) - - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass - - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - ) - - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - return tool_invoke_response, tool_invoke_meta - - def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: - """ - convert dict to action - """ - return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: - """ - fill in inputs from external data tools - """ - for key, value in inputs.items(): - try: - instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) - except Exception: - continue - - return instruction - - def _init_react_state(self, query): - """ - init agent scratchpad - """ - self._query = query - self._agent_scratchpad = [] - self._historic_prompt_messages = self._organize_historic_prompt_messages() - - @abstractmethod - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - organize prompt messages - """ - - def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: - """ - format assistant message - """ - message = "" - for scratchpad in agent_scratchpad: - if scratchpad.is_final(): - message += f"Final Answer: {scratchpad.agent_response}" - else: - message += f"Thought: {scratchpad.thought}\n\n" - if scratchpad.action_str: - message += f"Action: {scratchpad.action_str}\n\n" - if scratchpad.observation: - message += f"Observation: {scratchpad.observation}\n\n" - - return message - - def _organize_historic_prompt_messages( - self, current_session_messages: list[PromptMessage] | None = None - ) -> list[PromptMessage]: - """ - organize historic prompt messages - """ - result: list[PromptMessage] = [] - scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit | None = None - - for message in self.history_prompt_messages: - if isinstance(message, AssistantPromptMessage): - if not current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad = AgentScratchpadUnit( - agent_response=message.content, - thought=message.content or "I am thinking about how to help you", - action_str="", - action=None, - observation=None, - ) - scratchpads.append(current_scratchpad) - if message.tool_calls: - try: - current_scratchpad.action = AgentScratchpadUnit.Action( - action_name=message.tool_calls[0].function.name, - action_input=json.loads(message.tool_calls[0].function.arguments), - ) - current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) - except Exception: - logger.exception("Failed to parse tool call from assistant message") - elif isinstance(message, 
ToolPromptMessage): - if current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad.observation = message.content - else: - raise NotImplementedError("expected str type") - elif isinstance(message, UserPromptMessage): - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - scratchpads = [] - current_scratchpad = None - - result.append(message) - - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - - historic_prompts = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=current_session_messages or [], - history_messages=result, - memory=self.memory, - ).get_prompt() - return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py deleted file mode 100644 index 4d1d94eadc..0000000000 --- a/api/core/agent/cot_chat_agent_runner.py +++ /dev/null @@ -1,118 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - PromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotChatAgentRunner(CotAgentRunner): - def _organize_system_prompt(self) -> SystemPromptMessage: - """ - Organize system prompt - """ - assert self.app_config.agent - assert self.app_config.agent.prompt - - prompt_entity = self.app_config.agent.prompt - if not prompt_entity: - raise ValueError("Agent prompt configuration is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return SystemPromptMessage(content=system_prompt) - - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize - """ - # organize system prompt - system_message = self._organize_system_prompt() - - # organize current assistant messages - agent_scratchpad = self._agent_scratchpad - if not agent_scratchpad: - assistant_messages = [] - else: - assistant_message = AssistantPromptMessage(content="") - assistant_message.content = "" # FIXME: 
type check tell mypy that assistant_message.content is str - for unit in agent_scratchpad: - if unit.is_final(): - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Final Answer: {unit.agent_response}" - else: - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_message.content += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_message.content += f"Observation: {unit.observation}\n\n" - - assistant_messages = [assistant_message] - - # query messages - query_messages = self._organize_user_query(self._query, []) - - if assistant_messages: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages( - [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] - ) - messages = [ - system_message, - *historic_messages, - *query_messages, - *assistant_messages, - UserPromptMessage(content="continue"), - ] - else: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) - messages = [system_message, *historic_messages, *query_messages] - - # join all messages - return messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py deleted file mode 100644 index da9a001d84..0000000000 --- a/api/core/agent/cot_completion_agent_runner.py +++ /dev/null @@ -1,87 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotCompletionAgentRunner(CotAgentRunner): - def _organize_instruction_prompt(self) -> str: - """ - Organize instruction prompt - """ - if self.app_config.agent is None: - raise ValueError("Agent configuration is not set") - prompt_entity = self.app_config.agent.prompt - if prompt_entity is None: - raise ValueError("prompt entity is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return system_prompt - - def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: - """ - Organize historic prompt - """ - historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) - historic_prompt = "" - - for message in historic_prompt_messages: - if isinstance(message, UserPromptMessage): - historic_prompt += f"Question: {message.content}\n\n" - elif isinstance(message, AssistantPromptMessage): - if isinstance(message.content, str): - historic_prompt += message.content + "\n\n" - elif isinstance(message.content, list): - for content in message.content: - if not isinstance(content, TextPromptMessageContent): - continue - historic_prompt += content.data - - return historic_prompt - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize prompt messages - """ - # organize system prompt - system_prompt = self._organize_instruction_prompt() - - # organize historic prompt messages - historic_prompt = self._organize_historic_prompt() - - # organize 
current assistant messages - agent_scratchpad = self._agent_scratchpad - assistant_prompt = "" - for unit in agent_scratchpad or []: - if unit.is_final(): - assistant_prompt += f"Final Answer: {unit.agent_response}" - else: - assistant_prompt += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_prompt += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_prompt += f"Observation: {unit.observation}\n\n" - - # query messages - query_prompt = f"Question: {self._query}" - - # join all messages - prompt = ( - system_prompt.replace("{{historic_messages}}", historic_prompt) - .replace("{{agent_scratchpad}}", assistant_prompt) - .replace("{{query}}", query_prompt) - ) - - return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 220feced1d..56319a14a3 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,3 +1,5 @@ +import uuid +from collections.abc import Mapping from enum import StrEnum from typing import Any, Union @@ -92,3 +94,94 @@ class AgentInvokeMessage(ToolInvokeMessage): """ pass + + +class ExecutionContext(BaseModel): + """Execution context containing trace and audit information. + + This context carries all the IDs and metadata that are not part of + the core business logic but needed for tracing, auditing, and + correlation purposes. + """ + + user_id: str | None = None + app_id: str | None = None + conversation_id: str | None = None + message_id: str | None = None + tenant_id: str | None = None + + @classmethod + def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext": + """Create a minimal context with only essential fields.""" + return cls(user_id=user_id) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for passing to legacy code.""" + return { + "user_id": self.user_id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "message_id": self.message_id, + "tenant_id": self.tenant_id, + } + + def with_updates(self, **kwargs) -> "ExecutionContext": + """Create a new context with updated fields.""" + data = self.to_dict() + data.update(kwargs) + + return ExecutionContext( + user_id=data.get("user_id"), + app_id=data.get("app_id"), + conversation_id=data.get("conversation_id"), + message_id=data.get("message_id"), + tenant_id=data.get("tenant_id"), + ) + + +class AgentLog(BaseModel): + """ + Agent Log. 
+ """ + + class LogType(StrEnum): + """Type of agent log entry.""" + + ROUND = "round" # A complete iteration round + THOUGHT = "thought" # LLM thinking/reasoning + TOOL_CALL = "tool_call" # Tool invocation + + class LogMetadata(StrEnum): + STARTED_AT = "started_at" + FINISHED_AT = "finished_at" + ELAPSED_TIME = "elapsed_time" + TOTAL_PRICE = "total_price" + TOTAL_TOKENS = "total_tokens" + PROVIDER = "provider" + CURRENCY = "currency" + LLM_USAGE = "llm_usage" + + class LogStatus(StrEnum): + START = "start" + ERROR = "error" + SUCCESS = "success" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log") + label: str = Field(..., description="The label of the log") + log_type: LogType = Field(..., description="The type of the log") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") + status: LogStatus = Field(..., description="The status of the log") + data: Mapping[str, Any] = Field(..., description="Detailed log data") + metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log") + + +class AgentResult(BaseModel): + """ + Agent execution result. + """ + + text: str = Field(default="", description="The generated text") + files: list[Any] = Field(default_factory=list, description="Files produced during execution") + usage: Any | None = Field(default=None, description="LLM usage statistics") + finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py deleted file mode 100644 index dcc1326b33..0000000000 --- a/api/core/agent/fc_agent_runner.py +++ /dev/null @@ -1,465 +0,0 @@ -import json -import logging -from collections.abc import Generator -from copy import deepcopy -from typing import Any, Union - -from core.agent.base_agent_runner import BaseAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, - PromptMessage, - PromptMessageContentType, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from models.model import Message - -logger = logging.getLogger(__name__) - - -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: - """ - Run FunctionCall agent application - """ - self.query = query - app_generate_entity = self.application_generate_entity - - app_config = self.app_config - assert app_config is not None, "app_config is required" - assert app_config.agent is not None, "app_config.agent is required" - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - - assert app_config.agent - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # continue to run 
until there is not any tool call - function_call_state = True - llm_usage: dict[str, LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - - # get tracing instance - trace_manager = app_generate_entity.trace_manager - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - prompt_messages_tools = [] - - message_file_ids: list[str] = [] - agent_thought_id = self.create_agent_thought( - message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids - ) - - # recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=prompt_messages_tools, - stop=app_generate_entity.model_conf.stop, - stream=self.stream_tool_call, - user=self.user_id, - callbacks=[], - ) - - tool_calls: list[tuple[str, str, dict[str, Any]]] = [] - - # save full response - response = "" - - # save tool call names and inputs - tool_call_names = "" - tool_call_inputs = "" - - current_llm_usage = None - - if isinstance(chunks, Generator): - is_first_chunk = True - for chunk in chunks: - if is_first_chunk: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - is_first_chunk = False - # check if there is any tool call - if self.check_tool_calls(chunk): - function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if chunk.delta.message and chunk.delta.message.content: - if isinstance(chunk.delta.message.content, list): - for content in chunk.delta.message.content: - response += content.data - else: - response += str(chunk.delta.message.content) - - if chunk.delta.usage: - increase_usage(llm_usage, chunk.delta.usage) - current_llm_usage = chunk.delta.usage - - yield chunk - else: - result = chunks - # check if there is any tool call - if self.check_blocking_tool_calls(result): - function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force 
ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if result.usage: - increase_usage(llm_usage, result.usage) - current_llm_usage = result.usage - - if result.message and result.message.content: - if isinstance(result.message.content, list): - for content in result.message.content: - response += content.data - else: - response += str(result.message.content) - - if not result.message.content: - result.message.content = "" - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=result.prompt_messages, - system_fingerprint=result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=result.message, - usage=result.usage, - ), - ) - - assistant_message = AssistantPromptMessage(content="", tool_calls=[]) - if tool_calls: - assistant_message.tool_calls = [ - AssistantPromptMessage.ToolCall( - id=tool_call[0], - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) - ), - ) - for tool_call in tool_calls - ] - else: - assistant_message.content = response - - self._current_thoughts.append(assistant_message) - - # save thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=tool_call_names, - tool_input=tool_call_inputs, - thought=response, - tool_invoke_meta=None, - observation=None, - answer=response, - messages_ids=[], - llm_usage=current_llm_usage, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - final_answer += response + "\n" - - # call tools - tool_responses = [] - for tool_call_id, tool_call_name, tool_call_args in tool_calls: - tool_instance = tool_instances.get(tool_call_name) - if not tool_instance: - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), - } - else: - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - app_id=self.application_generate_entity.app_config.app_id, - message_id=self.message.id, - conversation_id=self.conversation.id, - ) - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict(), - } - - tool_responses.append(tool_response) - if tool_response["tool_response"] is not None: - self._current_thoughts.append( - ToolPromptMessage( - content=str(tool_response["tool_response"]), - tool_call_id=tool_call_id, - name=tool_call_name, - ) - ) - - if len(tool_responses) > 0: - # save agent thought - self.save_agent_thought( - 
agent_thought_id=agent_thought_id, - tool_name="", - tool_input="", - thought="", - tool_invoke_meta={ - tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses - }, - observation={ - tool_response["tool_call_name"]: tool_response["tool_response"] - for tool_response in tool_responses - }, - answer="", - messages_ids=message_file_ids, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool - for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: - """ - Check if there is any tool call in llm result chunk - """ - if llm_result_chunk.delta.message.tool_calls: - return True - return False - - def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: - """ - Check if there is any blocking tool call in llm result - """ - if llm_result.message.tool_calls: - return True - return False - - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract tool calls from llm result chunk - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result_chunk.delta.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract blocking tool calls from llm result - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Initialize system message - """ - if not prompt_messages and prompt_template: - return [ - SystemPromptMessage(content=prompt_template), - ] - - if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: - prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - - return prompt_messages or [] - - def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = 
image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - As for now, gpt supports both fc and vision at the first iteration. - We need to remove the image messages from the prompt messages at the first iteration. - """ - prompt_messages = deepcopy(prompt_messages) - - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, list): - prompt_message.content = "\n".join( - [ - content.data - if content.type == PromptMessageContentType.TEXT - else "[image]" - if content.type == PromptMessageContentType.IMAGE - else "[file]" - for content in prompt_message.content - ] - ) - - return prompt_messages - - def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or "" - self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query or "", []) - - self.history_prompt_messages = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=[*query_prompt_messages, *self._current_thoughts], - history_messages=self.history_prompt_messages, - memory=self.memory, - ).get_prompt() - - prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] - if len(self._current_thoughts) != 0: - # clear messages after the first iteration - prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) - return prompt_messages diff --git a/api/core/agent/patterns/README.md b/api/core/agent/patterns/README.md new file mode 100644 index 0000000000..f6437ba05a --- /dev/null +++ b/api/core/agent/patterns/README.md @@ -0,0 +1,67 @@ +# Agent Patterns + +A unified agent pattern module that provides common agent execution strategies for both Agent V2 nodes and Agent Applications in Dify. + +## Overview + +This module implements a strategy pattern for agent execution, automatically selecting the appropriate strategy based on model capabilities. It serves as the core engine for agent-based interactions across different components of the Dify platform. + +## Key Features + +### 1. Multiple Agent Strategies + +- **Function Call Strategy**: Leverages native function/tool calling capabilities of advanced LLMs (e.g., GPT-4, Claude) +- **ReAct Strategy**: Implements the ReAct (Reasoning + Acting) approach for models without native function calling support + +### 2. Automatic Strategy Selection + +The `StrategyFactory` intelligently selects the optimal strategy based on model features: + +- Models with `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL` capabilities → Function Call Strategy +- Other models → ReAct Strategy + +### 3. 
Unified Interface + +- Common base class (`AgentPattern`) ensures consistent behavior across strategies +- Seamless integration with both workflow nodes and standalone agent applications +- Standardized input/output formats for easy consumption + +### 4. Advanced Capabilities + +- **Streaming Support**: Real-time response streaming for better user experience +- **File Handling**: Built-in support for processing and managing files during agent execution +- **Iteration Control**: Configurable maximum iterations with safety limits (capped at 99) +- **Tool Management**: Flexible tool integration supporting various tool types +- **Context Propagation**: Execution context for tracing, auditing, and debugging + +## Architecture + +``` +agent/patterns/ +├── base.py # Abstract base class defining the agent pattern interface +├── function_call.py # Implementation using native LLM function calling +├── react.py # Implementation using ReAct prompting approach +└── strategy_factory.py # Factory for automatic strategy selection +``` + +## Usage + +The module is designed to be used by: + +1. **Agent V2 Nodes**: In workflow orchestration for complex agent tasks +1. **Agent Applications**: For standalone conversational agents +1. **Custom Implementations**: As a foundation for building specialized agent behaviors + +## Integration Points + +- **Model Runtime**: Interfaces with Dify's model runtime for LLM interactions +- **Tool System**: Integrates with the tool framework for external capabilities +- **Memory Management**: Compatible with conversation memory systems +- **File Management**: Handles file inputs/outputs during agent execution + +## Benefits + +1. **Consistency**: Unified implementation reduces code duplication and maintenance overhead +1. **Flexibility**: Easy to extend with new strategies or customize existing ones +1. **Performance**: Optimized for each model's capabilities to ensure best performance +1. **Reliability**: Built-in safety mechanisms and error handling diff --git a/api/core/agent/patterns/__init__.py b/api/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..8a3b125533 --- /dev/null +++ b/api/core/agent/patterns/__init__.py @@ -0,0 +1,19 @@ +"""Agent patterns module. 
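
The README's "Unified Interface" point rests on every strategy exposing one generator-based `run()` that streams chunks and logs and hands back the final result as the generator's return value. The following is a minimal, self-contained sketch of that contract; all class names here are simplified stand-ins for illustration, not the patch's real `AgentPattern`, `LLMResultChunk`, `AgentLog`, or `AgentResult` types.

```python
# Minimal stand-in for the AgentPattern contract: run() is a generator that
# yields streaming events and *returns* the final result via StopIteration.value.
from abc import ABC, abstractmethod
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class Event:            # stand-in for LLMResultChunk / AgentLog
    text: str


@dataclass
class Result:           # stand-in for AgentResult
    text: str


class Pattern(ABC):     # stand-in for AgentPattern
    @abstractmethod
    def run(self, query: str) -> Generator[Event, None, Result]: ...


class EchoStrategy(Pattern):
    def run(self, query: str) -> Generator[Event, None, Result]:
        for word in query.split():
            yield Event(text=word)          # streamed piece by piece
        return Result(text=query.upper())   # final answer travels in StopIteration


def drive(pattern: Pattern, query: str) -> Result:
    """Consume the generator manually so the return value can be captured."""
    gen = pattern.run(query)
    while True:
        try:
            event = next(gen)
            print("chunk:", event.text)
        except StopIteration as stop:
            return stop.value


if __name__ == "__main__":
    print("result:", drive(EchoStrategy(), "hello agent patterns").text)
```

The same yield-then-return shape is why the concrete strategies below can use `yield from` helpers and still collect intermediate results into a final `AgentResult`.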
+ +This module provides different strategies for agent execution: +- FunctionCallStrategy: Uses native function/tool calling +- ReActStrategy: Uses ReAct (Reasoning + Acting) approach +- StrategyFactory: Factory for creating strategies based on model features +""" + +from .base import AgentPattern +from .function_call import FunctionCallStrategy +from .react import ReActStrategy +from .strategy_factory import StrategyFactory + +__all__ = [ + "AgentPattern", + "FunctionCallStrategy", + "ReActStrategy", + "StrategyFactory", +] diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py new file mode 100644 index 0000000000..9f010bed6a --- /dev/null +++ b/api/core/agent/patterns/base.py @@ -0,0 +1,444 @@ +"""Base class for agent strategies.""" + +from __future__ import annotations + +import json +import re +import time +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING, Any + +from core.agent.entities import AgentLog, AgentResult, ExecutionContext +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + +# Type alias for tool invoke hook +# Returns: (response_content, message_file_ids, tool_invoke_meta) +ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]] + + +class AgentPattern(ABC): + """Base class for agent execution strategies.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + ): + """Initialize the agent strategy.""" + self.model_instance = model_instance + self.tools = tools + self.context = context + self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations + self.workflow_call_depth = workflow_call_depth + self.files: list[File] = files + self.tool_invoke_hook = tool_invoke_hook + + @abstractmethod + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the agent strategy.""" + pass + + def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None: + """Accumulate LLM usage statistics.""" + if not total_usage.get("usage"): + # Create a copy to avoid modifying the original + total_usage["usage"] = LLMUsage( + prompt_tokens=delta_usage.prompt_tokens, + prompt_unit_price=delta_usage.prompt_unit_price, + prompt_price_unit=delta_usage.prompt_price_unit, + prompt_price=delta_usage.prompt_price, + completion_tokens=delta_usage.completion_tokens, + completion_unit_price=delta_usage.completion_unit_price, + completion_price_unit=delta_usage.completion_price_unit, + completion_price=delta_usage.completion_price, + total_tokens=delta_usage.total_tokens, + total_price=delta_usage.total_price, + currency=delta_usage.currency, + latency=delta_usage.latency, + ) + else: + current: LLMUsage = total_usage["usage"] + 
current.prompt_tokens += delta_usage.prompt_tokens + current.completion_tokens += delta_usage.completion_tokens + current.total_tokens += delta_usage.total_tokens + current.prompt_price += delta_usage.prompt_price + current.completion_price += delta_usage.completion_price + current.total_price += delta_usage.total_price + + def _extract_content(self, content: Any) -> str: + """Extract text content from message content.""" + if isinstance(content, list): + # Content items are PromptMessageContentUnionTypes + text_parts = [] + for c in content: + # Check if it's a TextPromptMessageContent (which has data attribute) + if isinstance(c, TextPromptMessageContent): + text_parts.append(c.data) + return "".join(text_parts) + return str(content) + + def _has_tool_calls(self, chunk: LLMResultChunk) -> bool: + """Check if chunk contains tool calls.""" + # LLMResultChunk always has delta attribute + return bool(chunk.delta.message and chunk.delta.message.tool_calls) + + def _has_tool_calls_result(self, result: LLMResult) -> bool: + """Check if result contains tool calls (non-streaming).""" + # LLMResult always has message attribute + return bool(result.message and result.message.tool_calls) + + def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from streaming chunk.""" + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + if chunk.delta.message and chunk.delta.message.tool_calls: + for tool_call in chunk.delta.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from non-streaming result.""" + tool_calls = [] + if result.message and result.message.tool_calls: + for tool_call in result.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_text_from_message(self, message: PromptMessage) -> str: + """Extract text content from a prompt message.""" + # PromptMessage always has content attribute + content = message.content + if isinstance(content, str): + return content + elif isinstance(content, list): + # Extract text from content list + text_parts = [] + for item in content: + if isinstance(item, TextPromptMessageContent): + text_parts.append(item.data) + return " ".join(text_parts) + return "" + + def _create_log( + self, + label: str, + log_type: AgentLog.LogType, + status: AgentLog.LogStatus, + data: dict[str, Any] | None = None, + parent_id: str | None = None, + extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None, + ) -> AgentLog: + """Create a new AgentLog with standard metadata.""" + metadata = { + AgentLog.LogMetadata.STARTED_AT: time.perf_counter(), + } + if extra_metadata: + metadata.update(extra_metadata) + + return AgentLog( + label=label, + log_type=log_type, + status=status, + data=data or {}, + parent_id=parent_id, + metadata=metadata, + ) + + def _finish_log( + self, + log: AgentLog, + data: dict[str, Any] | None = None, + usage: LLMUsage | None = None, + ) -> AgentLog: + """Finish an AgentLog by updating its status and 
metadata.""" + log.status = AgentLog.LogStatus.SUCCESS + + if data is not None: + log.data = data + + # Calculate elapsed time + started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter()) + finished_at = time.perf_counter() + + # Update metadata + log.metadata = { + **log.metadata, + AgentLog.LogMetadata.FINISHED_AT: finished_at, + AgentLog.LogMetadata.ELAPSED_TIME: finished_at - started_at, + } + + # Add usage information if provided + if usage: + log.metadata.update( + { + AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price, + AgentLog.LogMetadata.CURRENCY: usage.currency, + AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens, + AgentLog.LogMetadata.LLM_USAGE: usage, + } + ) + + return log + + def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]: + """ + Replace file references in tool arguments with actual File objects. + + Args: + tool_args: Dictionary of tool arguments + + Returns: + Updated tool arguments with file references replaced + """ + # Process each argument in the dictionary + processed_args: dict[str, Any] = {} + for key, value in tool_args.items(): + processed_args[key] = self._process_file_reference(value) + return processed_args + + def _process_file_reference(self, data: Any) -> Any: + """ + Recursively process data to replace file references. + Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...]. + + Args: + data: The data to process (can be dict, list, str, or other types) + + Returns: + Processed data with file references replaced + """ + single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$") + multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$") + + if isinstance(data, dict): + # Process dictionary recursively + return {key: self._process_file_reference(value) for key, value in data.items()} + elif isinstance(data, list): + # Process list recursively + return [self._process_file_reference(item) for item in data] + elif isinstance(data, str): + # Check for single file pattern [File: file_id] + single_match = single_file_pattern.match(data.strip()) + if single_match: + file_id = single_match.group(1).strip() + # Find the file in self.files + for file in self.files: + if file.id and str(file.id) == file_id: + return file + # If file not found, return original value + return data + + # Check for multiple files pattern [Files: file_id1, file_id2, ...] 
+ multiple_match = multiple_files_pattern.match(data.strip()) + if multiple_match: + file_ids_str = multiple_match.group(1).strip() + # Split by comma and strip whitespace + file_ids = [fid.strip() for fid in file_ids_str.split(",")] + + # Find all matching files + matched_files: list[File] = [] + for file_id in file_ids: + for file in self.files: + if file.id and str(file.id) == file_id: + matched_files.append(file) + break + + # Return list of files if any were found, otherwise return original + return matched_files or data + + return data + else: + # Return other types as-is + return data + + def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk: + """Create a text chunk for streaming.""" + return LLMResultChunk( + model=self.model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + usage=None, + ), + system_fingerprint="", + ) + + def _invoke_tool( + self, + tool_instance: Tool, + tool_args: dict[str, Any], + tool_name: str, + ) -> tuple[str, list[File], ToolInvokeMeta | None]: + """ + Invoke a tool and collect its response. + + Args: + tool_instance: The tool instance to invoke + tool_args: Tool arguments + tool_name: Name of the tool + + Returns: + Tuple of (response_content, tool_files, tool_invoke_meta) + """ + # Process tool_args to replace file references with actual File objects + tool_args = self._replace_file_references(tool_args) + + # If a tool invoke hook is set, use it instead of generic_invoke + if self.tool_invoke_hook: + response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name) + # Note: message_file_ids are stored in DB, we don't convert them to File objects here + # The caller (AgentAppRunner) handles file publishing + return response_content, [], tool_invoke_meta + + # Default: use generic_invoke for workflow scenarios + # Import here to avoid circular import + from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine + + tool_response = ToolEngine().generic_invoke( + tool=tool_instance, + tool_parameters=tool_args, + user_id=self.context.user_id or "", + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + app_id=self.context.app_id, + conversation_id=self.context.conversation_id, + message_id=self.context.message_id, + ) + + # Collect response and files + response_content = "" + tool_files: list[File] = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + response_content += response.message.text + + elif response.type == ToolInvokeMessage.MessageType.LINK: + # Handle link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Link: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE: + # Handle image URL messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK: + # Handle image link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK: + # Handle binary file link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + filename = 
response.meta.get("filename", "file") if response.meta else "file" + response_content += f"[File: {filename} - {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.JSON: + # Handle JSON messages + if isinstance(response.message, ToolInvokeMessage.JsonMessage): + response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2) + + elif response.type == ToolInvokeMessage.MessageType.BLOB: + # Handle blob messages - convert to text representation + if isinstance(response.message, ToolInvokeMessage.BlobMessage): + mime_type = ( + response.meta.get("mime_type", "application/octet-stream") + if response.meta + else "application/octet-stream" + ) + size = len(response.message.blob) + response_content += f"[Binary data: {mime_type}, size: {size} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + # Handle variable messages + if isinstance(response.message, ToolInvokeMessage.VariableMessage): + var_name = response.message.variable_name + var_value = response.message.variable_value + if isinstance(var_value, str): + response_content += var_value + else: + response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]" + + elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: + # Handle blob chunk messages - these are parts of a larger blob + if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage): + response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + # Handle retriever resources messages + if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage): + response_content += response.message.context + + elif response.type == ToolInvokeMessage.MessageType.FILE: + # Extract file from meta + if response.meta and "file" in response.meta: + file = response.meta["file"] + if isinstance(file, File): + # Check if file is for model or tool output + if response.meta.get("target") == "self": + # File is for model - add to files for next prompt + self.files.append(file) + response_content += f"File '{file.filename}' has been loaded into your context." 
+ else: + # File is tool output + tool_files.append(file) + + return response_content, tool_files, None + + def _find_tool_by_name(self, tool_name: str) -> Tool | None: + """Find a tool instance by its name.""" + for tool in self.tools: + if tool.entity.identity.name == tool_name: + return tool + return None + + def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]: + """Convert tools to prompt message format.""" + prompt_tools: list[PromptMessageTool] = [] + for tool in self.tools: + prompt_tools.append(tool.to_prompt_message_tool()) + return prompt_tools + + def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None: + """Initialize usage tracking with empty usage if not set.""" + if "usage" not in llm_usage or llm_usage["usage"] is None: + llm_usage["usage"] = LLMUsage.empty_usage() diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py new file mode 100644 index 0000000000..2c8664c419 --- /dev/null +++ b/api/core/agent/patterns/function_call.py @@ -0,0 +1,273 @@ +"""Function Call strategy implementation.""" + +import json +from collections.abc import Generator +from typing import Any, Union + +from core.agent.entities import AgentLog, AgentResult +from core.file import File +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, +) +from core.tools.entities.tool_entities import ToolInvokeMeta + +from .base import AgentPattern + + +class FunctionCallStrategy(AgentPattern): + """Function Call strategy using model's native tool calling capability.""" + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the function call agent strategy.""" + # Convert tools to prompt format + prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + + # Initialize tracking + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + function_call_state: bool = True + total_usage: dict[str, LLMUsage | None] = {"usage": None} + messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy + final_text: str = "" + finish_reason: str | None = None + output_files: list[File] = [] # Track files produced by tools + + while function_call_state and iteration_step <= max_iterations: + function_call_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"round_index": iteration_step}, + ) + yield round_log + # On last iteration, remove tools to force final answer + current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, LLMUsage | None] = {"usage": None} + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages, + model_parameters=model_parameters, + tools=current_tools, + stop=stop, + stream=stream, + 
user=self.context.user_id, + callbacks=[], + ) + + # Process response + tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log + ) + messages.append(self._create_assistant_message(response_content, tool_calls)) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update final text if no tool calls (this is likely the final answer) + if not tool_calls: + final_text = response_content + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Process tool calls + tool_outputs: dict[str, str] = {} + if tool_calls: + function_call_state = True + # Execute tools + for tool_call_id, tool_name, tool_args in tool_calls: + tool_response, tool_files, _ = yield from self._handle_tool_call( + tool_name, tool_args, tool_call_id, messages, round_log + ) + tool_outputs[tool_name] = tool_response + # Track files produced by tools + output_files.extend(tool_files) + yield self._finish_log( + round_log, + data={ + "llm_result": response_content, + "tool_calls": [ + {"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls + ] + if tool_calls + else [], + "final_answer": final_text if not function_call_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, + files=output_files, + usage=total_usage.get("usage") or LLMUsage.empty_usage(), + finish_reason=finish_reason, + ) + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, LLMUsage | None], + start_log: AgentLog, + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[list[tuple[str, str, dict[str, Any]]], str, str | None], + ]: + """Handle LLM response chunks and extract tool calls and content. + + Returns a tuple of (tool_calls, response_content, finish_reason). 
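
`FunctionCallStrategy.run` above drives an iterate-until-no-tool-calls loop and deliberately withholds tools on the last allowed round so the model is forced to produce a final answer. The following is a compressed, self-contained restatement of that control flow; `call_model` and `run_tool` are stand-in callables invented for the example, not the patch's `invoke_llm` or tool engine APIs.

```python
# Compressed restatement of the loop in FunctionCallStrategy.run.
# `call_model` and `run_tool` are stand-ins, not the patch's APIs.
from typing import Any, Callable

ToolCall = tuple[str, str, dict[str, Any]]  # (call_id, tool_name, arguments)


def agent_loop(
    call_model: Callable[[list[dict], list[str]], tuple[str, list[ToolCall]]],
    run_tool: Callable[[str, dict[str, Any]], str],
    messages: list[dict],
    tool_names: list[str],
    max_iterations: int = 10,
) -> str:
    final_text = ""
    for step in range(1, max_iterations + 2):                   # one extra round to force an answer
        tools = [] if step == max_iterations + 1 else tool_names  # last round: no tools offered
        text, tool_calls = call_model(messages, tools)
        messages.append({"role": "assistant", "content": text, "tool_calls": tool_calls})
        if not tool_calls:                                      # no tool use -> final answer
            final_text = text
            break
        for call_id, name, args in tool_calls:                  # execute tools, feed results back
            messages.append({"role": "tool", "tool_call_id": call_id,
                             "name": name, "content": run_tool(name, args)})
    return final_text


if __name__ == "__main__":
    state = {"called": False}

    def fake_model(msgs: list[dict], tools: list[str]) -> tuple[str, list[ToolCall]]:
        if tools and not state["called"]:
            state["called"] = True
            return "", [("c1", "search", {"q": "weather"})]
        return "It is sunny.", []

    print(agent_loop(fake_model, lambda name, args: f"{name} -> ok", [], ["search"]))
```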
+ """ + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + response_content: str = "" + finish_reason: str | None = None + if isinstance(chunks, Generator): + # Streaming response + for chunk in chunks: + # Extract tool calls + if self._has_tool_calls(chunk): + tool_calls.extend(self._extract_tool_calls(chunk)) + + # Extract content + if chunk.delta.message and chunk.delta.message.content: + response_content += self._extract_content(chunk.delta.message.content) + + # Track usage + if chunk.delta.usage: + self._accumulate_usage(llm_usage, chunk.delta.usage) + + # Capture finish reason + if chunk.delta.finish_reason: + finish_reason = chunk.delta.finish_reason + + yield chunk + else: + # Non-streaming response + result: LLMResult = chunks + + if self._has_tool_calls_result(result): + tool_calls.extend(self._extract_tool_calls_result(result)) + + if result.message and result.message.content: + response_content += self._extract_content(result.message.content) + + if result.usage: + self._accumulate_usage(llm_usage, result.usage) + + # Convert to streaming format + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), + ) + yield self._finish_log( + start_log, + data={ + "result": response_content, + }, + usage=llm_usage.get("usage"), + ) + return tool_calls, response_content, finish_reason + + def _create_assistant_message( + self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None + ) -> AssistantPromptMessage: + """Create assistant message with tool calls.""" + if tool_calls is None: + return AssistantPromptMessage(content=content) + return AssistantPromptMessage( + content=content or "", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=tc[0], + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])), + ) + for tc in tool_calls + ], + ) + + def _handle_tool_call( + self, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str, + messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]: + """Handle a single tool call and return response with files and meta.""" + # Find tool + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + raise ValueError(f"Tool {tool_name} not found") + + # Create tool call log + tool_call_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + ) + yield tool_call_log + + # Invoke tool using base class method + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name) + + yield self._finish_log( + tool_call_log, + data={ + **tool_call_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + final_content = response_content or "Tool executed successfully" + # Add tool response to messages + messages.append( + ToolPromptMessage( + content=final_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return response_content, tool_files, tool_invoke_meta diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py new file mode 100644 index 0000000000..46a0dbd61e --- /dev/null +++ 
b/api/core/agent/patterns/react.py @@ -0,0 +1,402 @@ +"""ReAct strategy implementation.""" + +from __future__ import annotations + +import json +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Union + +from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + SystemPromptMessage, +) + +from .base import AgentPattern, ToolInvokeHook + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class ReActStrategy(AgentPattern): + """ReAct strategy using reasoning and acting approach.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ): + """Initialize the ReAct strategy with instruction support.""" + super().__init__( + model_instance=model_instance, + tools=tools, + context=context, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + files=files, + tool_invoke_hook=tool_invoke_hook, + ) + self.instruction = instruction + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the ReAct agent strategy.""" + # Initialize tracking + agent_scratchpad: list[AgentScratchpadUnit] = [] + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + react_state: bool = True + total_usage: dict[str, Any] = {"usage": None} + output_files: list[File] = [] # Track files produced by tools + final_text: str = "" + finish_reason: str | None = None + + # Add "Observation" to stop sequences + if "Observation" not in stop: + stop = stop.copy() + stop.append("Observation") + + while react_state and iteration_step <= max_iterations: + react_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"round_index": iteration_step}, + ) + yield round_log + + # Build prompt with/without tools based on iteration + include_tools = iteration_step < max_iterations + current_messages = self._build_prompt_with_react_format( + prompt_messages, agent_scratchpad, include_tools, self.instruction + ) + + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, Any] = {"usage": None} + + # Use current messages directly (files are handled by base class if needed) + messages_to_use = current_messages + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=self.context.user_id or "", + callbacks=[], + ) + + # Process response + scratchpad, chunk_finish_reason = yield from 
self._handle_chunks( + chunks, round_usage, model_log, current_messages + ) + agent_scratchpad.append(scratchpad) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Check if we have an action to execute + if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + react_state = True + # Execute tool + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + # Track files produced by tools + output_files.extend(tool_files) + + # Add observation to scratchpad for display + yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages) + else: + # Extract final answer + if scratchpad.action and scratchpad.action.action_input: + final_answer = scratchpad.action.action_input + if isinstance(final_answer, dict): + final_answer = json.dumps(final_answer, ensure_ascii=False) + final_text = str(final_answer) + elif scratchpad.thought: + # If no action but we have thought, use thought as final answer + final_text = scratchpad.thought + + yield self._finish_log( + round_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + "observation": scratchpad.observation or None, + "final_answer": final_text if not react_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason + ) + + def _build_prompt_with_react_format( + self, + original_messages: list[PromptMessage], + agent_scratchpad: list[AgentScratchpadUnit], + include_tools: bool = True, + instruction: str = "", + ) -> list[PromptMessage]: + """Build prompt messages with ReAct format.""" + # Copy messages to avoid modifying original + messages = list(original_messages) + + # Find and update the system prompt that should already exist + system_prompt_found = False + for i, msg in enumerate(messages): + if isinstance(msg, SystemPromptMessage): + system_prompt_found = True + # The system prompt from frontend already has the template, just replace placeholders + + # Format tools + tools_str = "" + tool_names = [] + if include_tools and self.tools: + # Convert tools to prompt message tools format + prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools] + tool_names = [tool.name for tool in prompt_tools] + + # Format tools as JSON for comprehensive information + from core.model_runtime.utils.encoders import jsonable_encoder + + tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2) + tool_names_str = ", ".join(f'"{name}"' for name in tool_names) + else: + tools_str = "No tools available" + tool_names_str = "" + + # Replace placeholders in the existing system prompt + updated_content = msg.content + assert isinstance(updated_content, str) + updated_content = updated_content.replace("{{instruction}}", instruction) + updated_content = updated_content.replace("{{tools}}", tools_str) + updated_content = updated_content.replace("{{tool_names}}", tool_names_str) + + # Create new SystemPromptMessage with updated content + messages[i] = SystemPromptMessage(content=updated_content) + break + + # If no system prompt found, 
that's unexpected but add scratchpad anyway + if not system_prompt_found: + # This shouldn't happen if frontend is working correctly + pass + + # Format agent scratchpad + scratchpad_str = "" + if agent_scratchpad: + scratchpad_parts: list[str] = [] + for unit in agent_scratchpad: + if unit.thought: + scratchpad_parts.append(f"Thought: {unit.thought}") + if unit.action_str: + scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```") + if unit.observation: + scratchpad_parts.append(f"Observation: {unit.observation}") + scratchpad_str = "\n".join(scratchpad_parts) + + # If there's a scratchpad, append it to the last message + if scratchpad_str: + messages.append(AssistantPromptMessage(content=scratchpad_str)) + + return messages + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, Any], + model_log: AgentLog, + current_messages: list[PromptMessage], + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[AgentScratchpadUnit, str | None], + ]: + """Handle LLM response chunks and extract action/thought. + + Returns a tuple of (scratchpad_unit, finish_reason). + """ + usage_dict: dict[str, Any] = {} + + # Convert non-streaming to streaming format if needed + if isinstance(chunks, LLMResult): + # Create a generator from the LLMResult + def result_to_chunks() -> Generator[LLMResultChunk, None, None]: + yield LLMResultChunk( + model=chunks.model, + prompt_messages=chunks.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=chunks.message, + usage=chunks.usage, + finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do + ), + system_fingerprint=chunks.system_fingerprint or "", + ) + + streaming_chunks = result_to_chunks() + else: + streaming_chunks = chunks + + react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict) + + # Initialize scratchpad unit + scratchpad = AgentScratchpadUnit( + agent_response="", + thought="", + action_str="", + observation="", + action=None, + ) + + finish_reason: str | None = None + + # Process chunks + for chunk in react_chunks: + if isinstance(chunk, AgentScratchpadUnit.Action): + # Action detected + action_str = json.dumps(chunk.model_dump()) + scratchpad.agent_response = (scratchpad.agent_response or "") + action_str + scratchpad.action_str = action_str + scratchpad.action = chunk + + yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) + else: + # Text chunk + chunk_text = str(chunk) + scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text + scratchpad.thought = (scratchpad.thought or "") + chunk_text + + yield self._create_text_chunk(chunk_text, current_messages) + + # Update usage + if usage_dict.get("usage"): + if llm_usage.get("usage"): + self._accumulate_usage(llm_usage, usage_dict["usage"]) + else: + llm_usage["usage"] = usage_dict["usage"] + + # Clean up thought + scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you" + + # Finish model log + yield self._finish_log( + model_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + }, + usage=llm_usage.get("usage"), + ) + + return scratchpad, finish_reason + + def _handle_tool_call( + self, + action: AgentScratchpadUnit.Action, + prompt_messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File]]]: + """Handle tool call and return observation with files.""" + 
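
The ReAct strategy above fills the `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders in the existing system prompt and replays earlier rounds as a Thought / Action / Observation scratchpad appended as an assistant message. A small sketch of that rendering step follows; the template text and scratchpad records are illustrative assumptions, not taken from the patch.

```python
# Sketch of ReAct prompt assembly: placeholder substitution plus a
# Thought/Action/Observation scratchpad, mirroring _build_prompt_with_react_format.
import json

TEMPLATE = (
    "{{instruction}}\n\n"
    "You may use these tools:\n{{tools}}\n"
    "Valid tool names: {{tool_names}}\n"
)


def render_system_prompt(instruction: str, tools: list[dict]) -> str:
    tool_names = ", ".join(f'"{t["name"]}"' for t in tools) if tools else ""
    tools_str = json.dumps(tools, indent=2) if tools else "No tools available"
    return (
        TEMPLATE.replace("{{instruction}}", instruction)
        .replace("{{tools}}", tools_str)
        .replace("{{tool_names}}", tool_names)
    )


def render_scratchpad(rounds: list[dict]) -> str:
    parts: list[str] = []
    for unit in rounds:
        if unit.get("thought"):
            parts.append(f"Thought: {unit['thought']}")
        if unit.get("action"):
            # the patch additionally wraps the action JSON in a code fence; omitted here
            parts.append("Action:\n" + json.dumps(unit["action"]))
        if unit.get("observation"):
            parts.append(f"Observation: {unit['observation']}")
    return "\n".join(parts)


if __name__ == "__main__":
    tools = [{"name": "search", "description": "Web search", "parameters": {"q": "string"}}]
    print(render_system_prompt("Answer concisely.", tools))
    print(render_scratchpad([{"thought": "I should search.",
                              "action": {"action": "search", "action_input": {"q": "dify"}},
                              "observation": "Found docs."}]))
```

Adding `"Observation"` to the stop sequences, as `run()` does above, is what keeps the model from hallucinating its own observations before the tool has actually run.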
tool_name = action.action_name + tool_args: dict[str, Any] | str = action.action_input + + # Start tool log + tool_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + ) + yield tool_log + + # Find tool instance + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + # Finish tool log with error + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "error": f"Tool {tool_name} not found", + }, + ) + return f"Tool {tool_name} not found", [] + + # Ensure tool_args is a dict + tool_args_dict: dict[str, Any] + if isinstance(tool_args, str): + try: + tool_args_dict = json.loads(tool_args) + except json.JSONDecodeError: + tool_args_dict = {"input": tool_args} + elif not isinstance(tool_args, dict): + tool_args_dict = {"input": str(tool_args)} + else: + tool_args_dict = tool_args + + # Invoke tool using base class method + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name) + + # Finish tool log + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + + return response_content or "Tool executed successfully", tool_files diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py new file mode 100644 index 0000000000..ad26075291 --- /dev/null +++ b/api/core/agent/patterns/strategy_factory.py @@ -0,0 +1,107 @@ +"""Strategy factory for creating agent strategies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.agent.entities import AgentEntity, ExecutionContext +from core.file.models import File +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelFeature + +from .base import AgentPattern, ToolInvokeHook +from .function_call import FunctionCallStrategy +from .react import ReActStrategy + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class StrategyFactory: + """Factory for creating agent strategies based on model features.""" + + # Tool calling related features + TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL} + + @staticmethod + def create_strategy( + model_features: list[ModelFeature], + model_instance: ModelInstance, + context: ExecutionContext, + tools: list[Tool], + files: list[File], + max_iterations: int = 10, + workflow_call_depth: int = 0, + agent_strategy: AgentEntity.Strategy | None = None, + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ) -> AgentPattern: + """ + Create an appropriate strategy based on model features. 
+ + Args: + model_features: List of model features/capabilities + model_instance: Model instance to use + context: Execution context containing trace/audit information + tools: Available tools + files: Available files + max_iterations: Maximum iterations for the strategy + workflow_call_depth: Depth of workflow calls + agent_strategy: Optional explicit strategy override + tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke) + instruction: Optional instruction for ReAct strategy + + Returns: + AgentStrategy instance + """ + # If explicit strategy is provided and it's Function Calling, try to use it if supported + if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING: + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + # Fallback to ReAct if FC is requested but not supported + + # If explicit strategy is Chain of Thought (ReAct) + if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Default auto-selection logic + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + # Model supports native function calling + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + else: + # Use ReAct strategy for models without function calling + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b297f3ff20..8e920f369a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass, field from threading import Thread from typing import Any, Union @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + ChunkType, MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueAgentLogEvent, @@ -71,13 +73,115 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models import Account, Conversation, EndUser, Message, MessageFile +from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile from models.enums import CreatorUserRole from models.workflow import Workflow, WorkflowNodeExecutionModel logger = logging.getLogger(__name__) +@dataclass +class StreamEventBuffer: + """ + Buffer for recording stream events in order to reconstruct the generation sequence. 
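
`create_strategy` above honours an explicit strategy choice, but silently falls back to ReAct when Function Calling is requested on a model without tool-call features, and otherwise auto-selects on capabilities. Here is a condensed, self-contained restatement of that decision table; the enum values and return labels are simplified stand-ins rather than the patch's real types.

```python
# Condensed decision table of StrategyFactory.create_strategy:
# explicit FC honoured only when supported; explicit CoT always ReAct;
# otherwise pick by model capability. Names are simplified stand-ins.
from enum import Enum


class Feature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    STREAM_TOOL_CALL = "stream-tool-call"


class Strategy(Enum):
    FUNCTION_CALLING = "function-calling"
    CHAIN_OF_THOUGHT = "chain-of-thought"


TOOL_CALL_FEATURES = {Feature.TOOL_CALL, Feature.MULTI_TOOL_CALL, Feature.STREAM_TOOL_CALL}


def choose(features: set[Feature], requested: Strategy | None = None) -> str:
    supports_fc = bool(features & TOOL_CALL_FEATURES)
    if requested is Strategy.FUNCTION_CALLING and supports_fc:
        return "function_call"
    if requested is Strategy.CHAIN_OF_THOUGHT:
        return "react"
    # Auto-select; this branch is also the fallback when FC was requested but unsupported.
    return "function_call" if supports_fc else "react"


if __name__ == "__main__":
    print(choose({Feature.TOOL_CALL}))                                    # function_call
    print(choose(set(), Strategy.FUNCTION_CALLING))                       # react (fallback)
    print(choose({Feature.STREAM_TOOL_CALL}, Strategy.CHAIN_OF_THOUGHT))  # react
```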
+ Records the exact order of text chunks, thoughts, and tool calls as they stream. + """ + + # Accumulated reasoning content (each thought block is a separate element) + reasoning_content: list[str] = field(default_factory=list) + # Current reasoning buffer (accumulates until we see a different event type) + _current_reasoning: str = "" + # Tool calls with their details + tool_calls: list[dict] = field(default_factory=list) + # Tool call ID to index mapping for updating results + _tool_call_id_map: dict[str, int] = field(default_factory=dict) + # Sequence of events in stream order + sequence: list[dict] = field(default_factory=list) + # Current position in answer text + _content_position: int = 0 + # Track last event type to detect transitions + _last_event_type: str | None = None + + def _flush_current_reasoning(self) -> None: + """Flush accumulated reasoning to the list and add to sequence.""" + if self._current_reasoning.strip(): + self.reasoning_content.append(self._current_reasoning.strip()) + self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1}) + self._current_reasoning = "" + + def record_text_chunk(self, text: str) -> None: + """Record a text chunk event.""" + if not text: + return + + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + text_len = len(text) + start_pos = self._content_position + + # If last event was also content, extend it; otherwise create new + if self.sequence and self.sequence[-1].get("type") == "content": + self.sequence[-1]["end"] = start_pos + text_len + else: + self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len}) + + self._content_position += text_len + self._last_event_type = "content" + + def record_thought_chunk(self, text: str) -> None: + """Record a thought/reasoning chunk event.""" + if not text: + return + + # Accumulate thought content + self._current_reasoning += text + self._last_event_type = "thought" + + def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None: + """Record a tool call event.""" + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + # Check if this tool call already exists (we might get multiple chunks) + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + # Update arguments if provided + if tool_arguments: + self.tool_calls[idx]["arguments"] = tool_arguments + else: + # New tool call + tool_call = { + "id": tool_call_id or "", + "name": tool_name or "", + "arguments": tool_arguments or "", + "result": "", + } + self.tool_calls.append(tool_call) + idx = len(self.tool_calls) - 1 + self._tool_call_id_map[tool_call_id] = idx + self.sequence.append({"type": "tool_call", "index": idx}) + + self._last_event_type = "tool_call" + + def record_tool_result(self, tool_call_id: str, result: str) -> None: + """Record a tool result event (update existing tool call).""" + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + self.tool_calls[idx]["result"] = result + + def finalize(self) -> None: + """Finalize the buffer, flushing any pending data.""" + if self._last_event_type == "thought": + self._flush_current_reasoning() + + def has_data(self) -> bool: + """Check if there's any meaningful data recorded.""" + return bool(self.reasoning_content or self.tool_calls or self.sequence) + + class 
AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -145,6 +249,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory self._graph_runtime_state: GraphRuntimeState | None = None + # Stream event buffer for recording generation sequence + self._stream_buffer = StreamEventBuffer() self._seed_graph_runtime_state_from_queue_manager() def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -384,7 +490,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: - """Handle text chunk events.""" + """Handle text chunk events and record to stream buffer for sequence reconstruction.""" delta_text = event.text if delta_text is None: return @@ -406,9 +512,37 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if tts_publisher and queue_message: tts_publisher.publish(queue_message) + # Record stream event based on chunk type + chunk_type = event.chunk_type or ChunkType.TEXT + match chunk_type: + case ChunkType.TEXT: + self._stream_buffer.record_text_chunk(delta_text) + case ChunkType.THOUGHT: + self._stream_buffer.record_thought_chunk(delta_text) + case ChunkType.TOOL_CALL: + self._stream_buffer.record_tool_call( + tool_call_id=event.tool_call_id or "", + tool_name=event.tool_name or "", + tool_arguments=event.tool_arguments or "", + ) + case ChunkType.TOOL_RESULT: + self._stream_buffer.record_tool_result( + tool_call_id=event.tool_call_id or "", + result=delta_text, + ) + self._task_state.answer += delta_text yield self._message_cycle_manager.message_to_stream_response( - answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector + answer=delta_text, + message_id=self._message_id, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type.value if event.chunk_type else None, + tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_arguments=event.tool_arguments, + tool_files=event.tool_files, + tool_error=event.tool_error, + round_index=event.round_index, ) def _handle_iteration_start_event( @@ -842,6 +976,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] session.add_all(message_files) + # Save merged LLM generation detail from all LLM nodes + self._save_generation_detail(session=session, message=message) # Trigger MESSAGE_TRACE for tracing integrations if trace_manager: trace_manager.add_trace_task( @@ -850,6 +986,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) ) + def _save_generation_detail(self, *, session: Session, message: Message) -> None: + """ + Save LLM generation detail for Chatflow using stream event buffer. + The buffer records the exact order of events as they streamed, + allowing accurate reconstruction of the generation sequence. 
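
`StreamEventBuffer` above rebuilds the generation order by tracking character offsets for answer text, flushing consecutive thought chunks into one `reasoning_content` entry, and indexing tool calls, with `sequence` recording the interleaving. The sketch below is a compact re-statement of that bookkeeping (tool-result updates omitted), self-contained so it can be run and inspected on its own.

```python
# Compact re-statement of StreamEventBuffer's bookkeeping: content segments are
# tracked by character offsets, consecutive thoughts are flushed as one reasoning
# block, tool calls are indexed, and `sequence` records the interleaving order.
from dataclasses import dataclass, field


@dataclass
class Buffer:
    reasoning: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)
    sequence: list[dict] = field(default_factory=list)
    _thought: str = ""
    _pos: int = 0

    def _flush_thought(self) -> None:
        if self._thought.strip():
            self.reasoning.append(self._thought.strip())
            self.sequence.append({"type": "reasoning", "index": len(self.reasoning) - 1})
        self._thought = ""

    def text(self, chunk: str) -> None:
        self._flush_thought()
        if self.sequence and self.sequence[-1]["type"] == "content":
            self.sequence[-1]["end"] = self._pos + len(chunk)   # extend the open segment
        else:
            self.sequence.append({"type": "content", "start": self._pos, "end": self._pos + len(chunk)})
        self._pos += len(chunk)

    def thought(self, chunk: str) -> None:
        self._thought += chunk                                  # buffered until the next non-thought event

    def tool_call(self, name: str, args: str) -> None:
        self._flush_thought()
        self.tool_calls.append({"name": name, "arguments": args})
        self.sequence.append({"type": "tool_call", "index": len(self.tool_calls) - 1})

    def finalize(self) -> None:
        self._flush_thought()


if __name__ == "__main__":
    buf = Buffer()
    buf.thought("I should look this up. ")
    buf.tool_call("search", '{"q": "dify"}')
    buf.text("Here is ")
    buf.text("the answer.")
    buf.finalize()
    print(buf.sequence)   # reasoning -> tool_call -> one merged content segment
```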
+ """ + # Finalize the stream buffer to flush any pending data + self._stream_buffer.finalize() + + # Only save if there's meaningful data + if not self._stream_buffer.has_data(): + return + + reasoning_content = self._stream_buffer.reasoning_content + tool_calls = self._stream_buffer.tool_calls + sequence = self._stream_buffer.sequence + + # Check if generation detail already exists for this message + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None + existing.tool_calls = json.dumps(tool_calls) if tool_calls else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_content) if reasoning_content else None, + tool_calls=json.dumps(tool_calls) if tool_calls else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None: """ Extract model provider and model_id from workflow node executions. diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..f5cf7a2c56 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -3,10 +3,8 @@ from typing import cast from sqlalchemy import select -from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from core.agent.agent_app_runner import AgentAppRunner from core.agent.entities import AgentEntity -from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner @@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationError from extensions.ext_database import db @@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner): raise ValueError("Message not found") db.session.close() - runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] - # start agent runner - if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - # check LLM mode - if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: - runner_cls = CotChatAgentRunner - elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION: - runner_cls = CotCompletionAgentRunner - else: - raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") - elif 
agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - runner_cls = FunctionCallAgentRunner - else: - raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}") - - runner = runner_cls( + runner = AgentAppRunner( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, conversation=conversation_result, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..0f3f9972c3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -671,7 +671,7 @@ class WorkflowResponseConverter: task_id=task_id, data=AgentLogStreamResponse.Data( node_execution_id=event.node_execution_id, - id=event.id, + message_id=event.id, parent_id=event.parent_id, label=event.label, error=event.error, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 842ad545ad..09ac24a413 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, + ChunkType, MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, @@ -487,7 +488,17 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): if tts_publisher and queue_message: tts_publisher.publish(queue_message) - yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) + yield self._text_chunk_to_stream_response( + text=delta_text, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type, + tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_arguments=event.tool_arguments, + tool_files=event.tool_files, + tool_error=event.tool_error, + round_index=event.round_index, + ) def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" @@ -650,16 +661,37 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): session.add(workflow_app_log) def _text_chunk_to_stream_response( - self, text: str, from_variable_selector: list[str] | None = None + self, + text: str, + from_variable_selector: list[str] | None = None, + chunk_type: ChunkType | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, + round_index: int | None = None, ) -> TextChunkStreamResponse: """ Handle completed event. 
:param text: text :return: """ + from core.app.entities.task_entities import ChunkType as ResponseChunkType + response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), + data=TextChunkStreamResponse.Data( + text=text, + from_variable_selector=from_variable_selector, + chunk_type=ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files or [], + tool_error=tool_error, + round_index=round_index, + ), ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 0e125b3538..3161956c9b 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -455,12 +455,21 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunStreamChunkEvent): + from core.app.entities.queue_entities import ChunkType as QueueChunkType + self._publish_event( QueueTextChunkEvent( text=event.chunk, from_variable_selector=list(event.selector), in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, + chunk_type=QueueChunkType(event.chunk_type.value), + tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_arguments=event.tool_arguments, + tool_files=event.tool_files, + tool_error=event.tool_error, + round_index=event.round_index, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): diff --git a/api/core/app/entities/llm_generation_entities.py b/api/core/app/entities/llm_generation_entities.py new file mode 100644 index 0000000000..4e278249fe --- /dev/null +++ b/api/core/app/entities/llm_generation_entities.py @@ -0,0 +1,69 @@ +""" +LLM Generation Detail entities. + +Defines the structure for storing and transmitting LLM generation details +including reasoning content, tool calls, and their sequence. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class ContentSegment(BaseModel): + """Represents a content segment in the generation sequence.""" + + type: Literal["content"] = "content" + start: int = Field(..., description="Start position in the text") + end: int = Field(..., description="End position in the text") + + +class ReasoningSegment(BaseModel): + """Represents a reasoning segment in the generation sequence.""" + + type: Literal["reasoning"] = "reasoning" + index: int = Field(..., description="Index into reasoning_content array") + + +class ToolCallSegment(BaseModel): + """Represents a tool call segment in the generation sequence.""" + + type: Literal["tool_call"] = "tool_call" + index: int = Field(..., description="Index into tool_calls array") + + +SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment + + +class ToolCallDetail(BaseModel): + """Represents a tool call with its arguments and result.""" + + id: str = Field(default="", description="Unique identifier for the tool call") + name: str = Field(..., description="Name of the tool") + arguments: str = Field(default="", description="JSON string of tool arguments") + result: str = Field(default="", description="Result from the tool execution") + + +class LLMGenerationDetailData(BaseModel): + """ + Domain model for LLM generation detail. + + Contains the structured data for reasoning content, tool calls, + and their display sequence. 
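+ Illustrative example (hypothetical values):
+ LLMGenerationDetailData(
+ reasoning_content=["I should look this up first"],
+ tool_calls=[ToolCallDetail(id="call_1", name="search", arguments='{"q": "dify"}', result="...")],
+ sequence=[ReasoningSegment(index=0), ToolCallSegment(index=0), ContentSegment(start=0, end=42)],
+ )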
+ """ + + reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments") + tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details") + sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments") + + def is_empty(self) -> bool: + """Check if there's any meaningful generation detail.""" + return not self.reasoning_content and not self.tool_calls + + def to_response_dict(self) -> dict: + """Convert to dictionary for API response.""" + return { + "reasoning_content": self.reasoning_content, + "tool_calls": [tc.model_dump() for tc in self.tool_calls], + "sequence": [seg.model_dump() for seg in self.sequence], + } diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 77d6bf03b4..c767fcfc34 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -177,6 +177,15 @@ class QueueLoopCompletedEvent(AppQueueEvent): error: str | None = None +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity @@ -191,6 +200,28 @@ class QueueTextChunkEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = ChunkType.TEXT + """type of the chunk""" + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_files: list[str] = Field(default_factory=list) + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + + # Thought fields (when chunk_type == THOUGHT) + round_index: int | None = None + """current iteration round""" + class QueueAgentMessageEvent(AppQueueEvent): """ diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7692128985..4609cd87f6 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -116,6 +116,28 @@ class MessageStreamResponse(StreamResponse): answer: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import) + chunk_type: str | None = None + """type of the chunk: text, tool_call, tool_result, thought""" + + # Tool call fields (when chunk_type == "tool_call") + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == "tool_result") + tool_files: list[str] | None = None + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + + # Thought fields (when chunk_type == "thought") + round_index: int | None = None + """current iteration round""" + class MessageAudioStreamResponse(StreamResponse): """ @@ -585,6 +607,15 @@ class 
LoopNodeCompletedStreamResponse(StreamResponse): data: Data +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class TextChunkStreamResponse(StreamResponse): """ TextChunkStreamResponse entity @@ -598,6 +629,28 @@ class TextChunkStreamResponse(StreamResponse): text: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = ChunkType.TEXT + """type of the chunk""" + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_files: list[str] = Field(default_factory=list) + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + + # Thought fields (when chunk_type == THOUGHT) + round_index: int | None = None + """current iteration round""" + event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -746,7 +799,7 @@ class AgentLogStreamResponse(StreamResponse): """ node_execution_id: str - id: str + message_id: str label: str parent_id: str | None = None error: str | None = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 98548ddfbb..2405413d71 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -58,7 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.model import AppMode, Conversation, Message, MessageAgentThought +from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -425,11 +425,92 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) ) + # Save LLM generation detail if there's reasoning_content + self._save_generation_detail(session=session, message=message, llm_result=llm_result) + message_was_created.send( message, application_generate_entity=self._application_generate_entity, ) + def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None: + """ + Save LLM generation detail for Completion/Chat/Agent-Chat applications. + For Agent-Chat, also merges MessageAgentThought records. 
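+ Illustrative example (hypothetical thought): a MessageAgentThought with
+ thought="search first", tool="web_search" and answer="It is sunny." maps to
+ one reasoning entry, one tool_call entry and a content span
+ {"type": "content", "start": 0, "end": 12}, in that order.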
+ """ + import json + + reasoning_list: list[str] = [] + tool_calls_list: list[dict] = [] + sequence: list[dict] = [] + answer = message.answer or "" + + # Check if this is Agent-Chat mode by looking for agent thoughts + agent_thoughts = ( + session.query(MessageAgentThought) + .filter_by(message_id=message.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) + + if agent_thoughts: + # Agent-Chat mode: merge MessageAgentThought records + content_pos = 0 + for thought in agent_thoughts: + # Add thought/reasoning + if thought.thought: + reasoning_list.append(thought.thought) + sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1}) + + # Add tool calls + if thought.tool: + tool_calls_list.append( + { + "name": thought.tool, + "arguments": thought.tool_input or "", + "result": thought.observation or "", + } + ) + sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1}) + + # Add answer content if present + if thought.answer: + start = content_pos + end = content_pos + len(thought.answer) + sequence.append({"type": "content", "start": start, "end": end}) + content_pos = end + else: + # Completion/Chat mode: use reasoning_content from llm_result + reasoning_content = llm_result.reasoning_content + if reasoning_content: + reasoning_list = [reasoning_content] + # Content comes first, then reasoning + if answer: + sequence.append({"type": "content", "start": 0, "end": len(answer)}) + sequence.append({"type": "reasoning", "index": 0}) + + # Only save if there's meaningful generation detail + if not reasoning_list and not tool_calls_list: + return + + # Check if generation detail already exists + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None + existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_list) if reasoning_list else None, + tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + def _handle_stop(self, event: QueueStopEvent): """ Handle stop. diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2e6f92efa5..414fed6701 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -214,12 +214,30 @@ class MessageCycleManager: return None def message_to_stream_response( - self, answer: str, message_id: str, from_variable_selector: list[str] | None = None + self, + answer: str, + message_id: str, + from_variable_selector: list[str] | None = None, + chunk_type: str | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, + round_index: int | None = None, ) -> MessageStreamResponse: """ Message to stream response. 
:param answer: answer :param message_id: message id + :param from_variable_selector: from variable selector + :param chunk_type: type of the chunk (text, function_call, tool_result, thought) + :param tool_call_id: unique identifier for this tool call + :param tool_name: name of the tool being called + :param tool_arguments: accumulated tool arguments JSON + :param tool_files: file IDs produced by tool + :param tool_error: error message if tool failed + :param round_index: current iteration round :return: """ with Session(db.engine, expire_on_commit=False) as session: @@ -232,6 +250,13 @@ class MessageCycleManager: answer=answer, from_variable_selector=from_variable_selector, event=event_type, + chunk_type=chunk_type, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files, + tool_error=tool_error, + round_index=round_index, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4436773d25..79b0c702e0 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -29,6 +29,7 @@ from models import ( Account, CreatorUserRole, EndUser, + LLMGenerationDetail, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, ) @@ -457,6 +458,94 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) session.merge(db_model) session.flush() + # Save LLMGenerationDetail for LLM nodes with successful execution + if ( + domain_model.node_type == NodeType.LLM + and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED + and domain_model.outputs is not None + ): + self._save_llm_generation_detail(session, domain_model) + + def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None: + """ + Save LLM generation detail for LLM nodes. + Extracts reasoning_content, tool_calls, and sequence from outputs and metadata. + """ + outputs = execution.outputs or {} + metadata = execution.metadata or {} + + # Extract reasoning_content from outputs + reasoning_content = outputs.get("reasoning_content") + reasoning_list: list[str] = [] + if reasoning_content: + # reasoning_content could be a string or already a list + if isinstance(reasoning_content, str): + reasoning_list = [reasoning_content] if reasoning_content else [] + elif isinstance(reasoning_content, list): + reasoning_list = reasoning_content + + # Extract tool_calls from metadata.agent_log + tool_calls_list: list[dict] = [] + agent_log = metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) + if agent_log and isinstance(agent_log, list): + for log in agent_log: + # Each log entry has label, data, status, etc. 
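+ # Example data payload (hypothetical values):
+ # {"tool_name": "web_search", "tool_call_id": "call_1", "tool_args": {"q": "dify"}, "output": "..."}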
+ log_data = log.data if hasattr(log, "data") else log.get("data", {}) + if log_data.get("tool_name"): + tool_calls_list.append( + { + "id": log_data.get("tool_call_id", ""), + "name": log_data.get("tool_name", ""), + "arguments": json.dumps(log_data.get("tool_args", {})), + "result": str(log_data.get("output", "")), + } + ) + + # Build sequence based on content, reasoning, and tool_calls + sequence: list[dict] = [] + text = outputs.get("text", "") + + # For now, use a simple sequence: content -> reasoning -> tool_calls + # This can be enhanced later to track actual streaming order + if text: + sequence.append({"type": "content", "start": 0, "end": len(text)}) + for i, _ in enumerate(reasoning_list): + sequence.append({"type": "reasoning", "index": i}) + for i in range(len(tool_calls_list)): + sequence.append({"type": "tool_call", "index": i}) + + # Only save if there's meaningful generation detail + if not reasoning_list and not tool_calls_list: + return + + # Check if generation detail already exists for this node execution + existing = ( + session.query(LLMGenerationDetail) + .filter_by( + workflow_run_id=execution.workflow_execution_id, + node_id=execution.node_id, + ) + .first() + ) + + if existing: + # Update existing record + existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None + existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + # Create new record + generation_detail = LLMGenerationDetail( + tenant_id=self._tenant_id, + app_id=self._app_id, + workflow_run_id=execution.workflow_execution_id, + node_id=execution.node_id, + reasoning_content=json.dumps(reasoning_list) if reasoning_list else None, + tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + def get_db_models_by_workflow_run( self, workflow_run_id: str, diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 8ca4eabb7a..cdbfd027ee 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from models.model import File +from core.model_runtime.entities.message_entities import PromptMessageTool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, @@ -152,6 +153,60 @@ class Tool(ABC): return parameters + def to_prompt_message_tool(self) -> PromptMessageTool: + message_tool = PromptMessageTool( + name=self.entity.identity.name, + description=self.entity.description.llm if self.entity.description else "", + parameters={ + "type": "object", + "properties": {}, + "required": [], + }, + ) + + parameters = self.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + # Determine the description based on parameter type + if parameter.type == ToolParameter.ToolParameterType.FILE: + file_format_desc = " Input the file id with format: [File: file_id]." + else: + file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. 
" + + message_tool.parameters["properties"][parameter.name] = { + "type": "string", + "description": (parameter.llm_description or "") + file_format_desc, + } + continue + enum = [] + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] if parameter.options else [] + + message_tool.parameters["properties"][parameter.name] = ( + { + "type": parameter_type, + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else parameter.input_schema + ) + + if len(enum) > 0: + message_tool.parameters["properties"][parameter.name]["enum"] = enum + + if parameter.required: + message_tool.parameters["required"].append(parameter.name) + + return message_tool + def create_image_message( self, image: str, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index cf12d5ec1f..3a60d34691 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -247,6 +247,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output DATASOURCE_INFO = "datasource_info" + LLM_CONTENT_SEQUENCE = "llm_content_sequence" class WorkflowNodeExecutionStatus(StrEnum): diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 98e0ea91ef..bd20c4f334 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -321,11 +321,20 @@ class ResponseStreamCoordinator: selector: Sequence[str], chunk: str, is_final: bool = False, + **extra_fields, ) -> NodeRunStreamChunkEvent: """Create a stream chunk event with consistent structure. For selectors with special prefixes (sys, env, conversation), we use the active response node's information since these are not actual node IDs. + + Args: + node_id: The node ID to attribute the event to + execution_id: The execution ID for this node + selector: The variable selector + chunk: The chunk content + is_final: Whether this is the final chunk + **extra_fields: Additional fields for specialized events (chunk_type, tool_call_id, etc.) """ # Check if this is a special selector that doesn't correspond to a node if selector and selector[0] not in self._graph.nodes and self._active_session: @@ -338,6 +347,7 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, + **extra_fields, ) # Standard case: selector refers to an actual node @@ -349,6 +359,7 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, + **extra_fields, ) def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: @@ -356,6 +367,8 @@ class ResponseStreamCoordinator: Handles both regular node selectors and special system selectors (sys, env, conversation). For special selectors, we attribute the output to the active response node. + + For object-type variables, automatically streams all child fields that have stream events. 
""" events: list[NodeRunStreamChunkEvent] = [] source_selector_prefix = segment.selector[0] if segment.selector else "" @@ -372,49 +385,93 @@ class ResponseStreamCoordinator: output_node_id = source_selector_prefix execution_id = self._get_or_create_execution_id(output_node_id) - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) + # Check if there's a direct stream for this selector + has_direct_stream = ( + tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams + ) - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True + if has_direct_stream: + # Stream all available chunks for direct stream + while self._has_unread_stream(segment.selector): + if event := self._pop_stream_chunk(segment.selector): + # For special selectors, update the event to use active response node's information + if self._active_session and source_selector_prefix not in self._graph.nodes: + response_node = self._graph.nodes[self._active_session.node_id] + updated_event = NodeRunStreamChunkEvent( + id=execution_id, + node_id=response_node.id, + node_type=response_node.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + ) + events.append(updated_event) + else: + events.append(event) - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, + # Check if stream is closed + if self._is_stream_closed(segment.selector): + is_complete = True + + else: + # No direct stream - check for child field streams (for object types) + child_streams = self._find_child_streams(segment.selector) + + if child_streams: + # Process all child streams + all_children_complete = True + + for child_selector in sorted(child_streams): + # Stream all available chunks from this child + while self._has_unread_stream(child_selector): + if event := self._pop_stream_chunk(child_selector): + # Forward child stream event + if self._active_session and source_selector_prefix not in self._graph.nodes: + response_node = self._graph.nodes[self._active_session.node_id] + updated_event = NodeRunStreamChunkEvent( + id=execution_id, + node_id=response_node.id, + node_type=response_node.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=event.chunk_type, + 
tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_arguments=event.tool_arguments, + tool_files=event.tool_files, + tool_error=event.tool_error, + round_index=event.round_index, + ) + events.append(updated_event) + else: + events.append(event) + + # Check if this child stream is complete + if not self._is_stream_closed(child_selector): + all_children_complete = False + + # Object segment is complete only when all children are complete + is_complete = all_children_complete + + # Fallback: check if scalar value exists in variable pool + if not is_complete and not has_direct_stream: + if value := self._variable_pool.get(segment.selector): + # Process scalar value + is_last_segment = bool( + self._active_session + and self._active_session.index == len(self._active_session.template.segments) - 1 ) - ) - is_complete = True + events.append( + self._create_stream_chunk_event( + node_id=output_node_id, + execution_id=execution_id, + selector=segment.selector, + chunk=value.markdown, + is_final=is_last_segment, + ) + ) + is_complete = True return events, is_complete @@ -513,6 +570,36 @@ class ResponseStreamCoordinator: # ============= Internal Stream Management Methods ============= + def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]: + """Find all child stream selectors that are descendants of the parent selector. + + For example, if parent_selector is ['llm', 'generation'], this will find: + - ['llm', 'generation', 'content'] + - ['llm', 'generation', 'tool_calls'] + - ['llm', 'generation', 'tool_results'] + - ['llm', 'generation', 'thought'] + + Args: + parent_selector: The parent selector to search for children + + Returns: + List of child selector tuples found in stream buffers or closed streams + """ + parent_key = tuple(parent_selector) + parent_len = len(parent_key) + child_streams: set[tuple[str, ...]] = set() + + # Search in both active buffers and closed streams + all_selectors = set(self._stream_buffers.keys()) | self._closed_streams + + for selector_key in all_selectors: + # Check if this selector is a direct child of the parent + # Direct child means: len(child) == len(parent) + 1 and child starts with parent + if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key: + child_streams.add(selector_key) + + return sorted(child_streams) + def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: """ Append a stream chunk to the internal buffer. 
diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 7a5edbb331..6c37fa1bc6 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -36,6 +36,7 @@ from .loop import ( # Node events from .node import ( + ChunkType, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -48,6 +49,7 @@ from .node import ( __all__ = [ "BaseGraphEvent", + "ChunkType", "GraphEngineEvent", "GraphNodeEventBase", "GraphRunAbortedEvent", diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index f225798d41..c7f76c424d 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -1,5 +1,6 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field @@ -21,13 +22,37 @@ class NodeRunStartedEvent(GraphNodeEventBase): provider_id: str = "" +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields + """Stream chunk event for workflow node execution.""" + + # Base fields selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call_id: str | None = Field(default=None, description="unique identifier for this tool call") + tool_name: str | None = Field(default=None, description="name of the tool being called") + tool_arguments: str | None = Field(default=None, description="accumulated tool arguments JSON") + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool") + tool_error: str | None = Field(default=None, description="error message if tool failed") + + # Thought fields (when chunk_type == THOUGHT) + round_index: int | None = Field(default=None, description="current iteration round") class NodeRunRetrieverResourceEvent(GraphNodeEventBase): diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index f14a594c85..67263311b9 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,16 +13,21 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( + ChunkType, ModelInvokeCompletedEvent, PauseRequestedEvent, RunRetrieverResourceEvent, RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) __all__ = [ "AgentLogEvent", + "ChunkType", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", @@ -39,4 +44,7 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "ThoughtChunkEvent", + "ToolCallChunkEvent", + "ToolResultChunkEvent", ] diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..3a062b9c4c 100644 --- a/api/core/workflow/node_events/node.py +++ 
b/api/core/workflow/node_events/node.py @@ -1,5 +1,6 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field @@ -30,13 +31,50 @@ class RunRetryEvent(NodeEventBase): start_at: datetime = Field(..., description="Retry start time") +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields + """Base stream chunk event - normal text streaming output.""" + selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + + +class ToolCallChunkEvent(StreamChunkEvent): + """Tool call streaming event - tool call arguments streaming output.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True) + tool_call_id: str = Field(..., description="unique identifier for this tool call") + tool_name: str = Field(..., description="name of the tool being called") + tool_arguments: str = Field(default="", description="accumulated tool arguments JSON") + + +class ToolResultChunkEvent(StreamChunkEvent): + """Tool result event - tool execution result.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True) + tool_call_id: str = Field(..., description="identifier of the tool call this result belongs to") + tool_name: str = Field(..., description="name of the tool") + tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool") + tool_error: str | None = Field(default=None, description="error message if tool failed") + + +class ThoughtChunkEvent(StreamChunkEvent): + """Agent thought streaming event - Agent thinking process (ReAct).""" + + chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True) + round_index: int = Field(default=1, description="current iteration round") class StreamCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index c2e1105971..9be16d4f08 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -46,6 +46,9 @@ from core.workflow.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) from core.workflow.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now @@ -536,6 +539,8 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + return NodeRunStreamChunkEvent( id=self._node_execution_id, node_id=self._node_id, @@ -543,6 +548,57 @@ class Node(Generic[NodeDataT]): selector=event.selector, chunk=event.chunk, is_final=event.is_final, + chunk_type=ChunkType(event.chunk_type.value), + ) + + @_dispatch.register + def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, 
+ node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_CALL, + tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_arguments=event.tool_arguments, + ) + + @_dispatch.register + def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_RESULT, + tool_call_id=event.tool_call_id, + tool_name=event.tool_name, + tool_files=event.tool_files, + tool_error=event.tool_error, + ) + + @_dispatch.register + def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.THOUGHT, + round_index=event.round_index, ) @_dispatch.register diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index f7bc713f63..edd0d3d581 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -3,6 +3,7 @@ from .entities import ( LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + ToolMetadata, VisionConfig, ) from .node import LLMNode @@ -13,5 +14,6 @@ __all__ = [ "LLMNodeCompletionModelPromptTemplate", "LLMNodeData", "ModelConfig", + "ToolMetadata", "VisionConfig", ] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index fe6f2290aa..fbdd1daec7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.entities import ImagePromptMessageContent, LLMMode from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.tools.entities.tool_entities import ToolProviderType from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector @@ -58,6 +59,30 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): jinja2_text: str | None = None +class ToolMetadata(BaseModel): + """ + Tool metadata for LLM node with tool support. + + Defines the essential fields needed for tool configuration, + particularly the 'type' field to identify tool provider type. 
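+ Illustrative example (hypothetical tool configuration, values for illustration only):
+ ToolMetadata(
+ type="builtin",
+ provider_name="time",
+ tool_name="current_time",
+ parameters={"timezone": "UTC"},
+ )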
+ """ + + # Core fields + enabled: bool = True + type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow") + provider_name: str = Field(..., description="Tool provider name/identifier") + tool_name: str = Field(..., description="Tool name") + + # Optional fields + plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools") + credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication") + + # Configuration fields + parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters") + settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration") + extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description") + + class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate @@ -86,6 +111,9 @@ class LLMNodeData(BaseNodeData): ), ) + # Tool support (from Agent V2) + tools: Sequence[ToolMetadata] = Field(default_factory=list) + @field_validator("prompt_config", mode="before") @classmethod def convert_none_prompt_config(cls, v: Any): diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 0c545469bc..e9c363851f 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,3 +1,4 @@ +import re from collections.abc import Sequence from typing import cast @@ -154,3 +155,94 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs ) session.execute(stmt) session.commit() + + +class ThinkTagStreamParser: + """Lightweight state machine to split streaming chunks by <think> tags.""" + + _START_PATTERN = re.compile(r"<think(\s[^>]*)?>", re.IGNORECASE) + _END_PATTERN = re.compile(r"</think\s*>", re.IGNORECASE) + _START_PREFIX = "<think" + _END_PREFIX = "</think" + + def __init__(self): + self._buffer = "" + self._in_think = False + + def _suffix_prefix_len(self, text: str, prefix: str) -> int: + """Return length of the longest suffix of `text` that is a prefix of `prefix`.""" + max_len = min(len(text), len(prefix) - 1) + for i in range(max_len, 0, -1): + if text[-i:].lower() == prefix[:i].lower(): + return i + return 0 + + def process(self, chunk: str) -> list[tuple[str, str]]: + """ + Split incoming chunk into ('thought' | 'text', content) tuples. + Content excludes the tags themselves and handles split tags across chunks.
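+ Example (illustrative):
+ parser = ThinkTagStreamParser()
+ parser.process("<think>plan")     # -> [("thought", "plan")]
+ parser.process("</think>answer")  # -> [("text", "answer")]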
+ """ + parts: list[tuple[str, str]] = [] + self._buffer += chunk + + while self._buffer: + if self._in_think: + end_match = self._END_PATTERN.search(self._buffer) + if end_match: + thought_text = self._buffer[: end_match.start()] + if thought_text: + parts.append(("thought", thought_text)) + self._buffer = self._buffer[end_match.end() :] + self._in_think = False + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("thought", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + start_match = self._START_PATTERN.search(self._buffer) + if start_match: + prefix = self._buffer[: start_match.start()] + if prefix: + parts.append(("text", prefix)) + self._buffer = self._buffer[start_match.end() :] + self._in_think = True + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("text", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + cleaned_parts: list[tuple[str, str]] = [] + for kind, content in parts: + # Extra safeguard: strip any stray tags that slipped through. + content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + if content: + cleaned_parts.append((kind, content)) + + return cleaned_parts + + def flush(self) -> list[tuple[str, str]]: + """Flush remaining buffer when the stream ends.""" + if not self._buffer: + return [] + kind = "thought" if self._in_think else "text" + content = self._buffer + # Drop dangling partial tags instead of emitting them + if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX): + content = "" + self._buffer = "" + if not content: + return [] + # Strip any complete tags that might still be present. 
+ content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + return [(kind, content)] if content else [] diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..bf41f476fd 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,6 +7,8 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext +from core.agent.patterns import StrategyFactory from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage @@ -44,6 +46,8 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.__base.tool import Tool +from core.tools.tool_manager import ToolManager from core.variables import ( ArrayFileSegment, ArraySegment, @@ -61,12 +65,16 @@ from core.workflow.enums import ( WorkflowNodeExecutionStatus, ) from core.workflow.node_events import ( + AgentLogEvent, ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node @@ -147,7 +155,8 @@ class LLMNode(Node[LLMNodeData]): clean_text = "" usage = LLMUsage.empty_usage() finish_reason = None - reasoning_content = None + reasoning_content = "" # Initialize as empty string for consistency + clean_text = "" # Initialize clean_text to avoid UnboundLocalError variable_pool = self.graph_runtime_state.variable_pool try: @@ -163,6 +172,15 @@ class LLMNode(Node[LLMNodeData]): # merge inputs inputs.update(jinja_inputs) + # Add all inputs to node_inputs for logging + node_inputs.update(inputs) + + # Add tools to inputs if configured + if self.tool_call_enabled: + node_inputs["tools"] = [ + {"provider_id": tool.provider_name, "tool_name": tool.tool_name} for tool in self._node_data.tools + ] + # fetch files files = ( llm_utils.fetch_files( @@ -222,21 +240,39 @@ class LLMNode(Node[LLMNodeData]): tenant_id=self.tenant_id, ) - # handle invoke result - generator = LLMNode.invoke_llm( - node_data_model=self.node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - reasoning_format=self.node_data.reasoning_format, - ) + # Check if tools are configured + if self.tool_call_enabled: + # Use tool-enabled invocation (Agent V2 style) + # This generator handles all events including final events + generator = self._invoke_llm_with_tools( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + files=files, + variable_pool=variable_pool, + node_inputs=node_inputs, + process_data=process_data, + ) + # Forward all events and return early since 
_invoke_llm_with_tools + # already sends final event and StreamCompletedEvent + yield from generator + return + else: + # Use traditional LLM invocation + generator = LLMNode.invoke_llm( + node_data_model=self._node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self._node_id, + node_type=self.node_type, + reasoning_format=self._node_data.reasoning_format, + ) structured_output: LLMStructuredOutput | None = None @@ -287,6 +323,11 @@ class LLMNode(Node[LLMNodeData]): "reasoning_content": reasoning_content, "usage": jsonable_encoder(usage), "finish_reason": finish_reason, + "generation": { + "content": clean_text, + "reasoning_content": [reasoning_content] if reasoning_content else [], + "tool_calls": [], + }, } if structured_output: outputs["structured_output"] = structured_output.structured_output @@ -1204,6 +1245,398 @@ class LLMNode(Node[LLMNodeData]): def retry(self) -> bool: return self.node_data.retry_config.retry_enabled + @property + def tool_call_enabled(self) -> bool: + return ( + self.node_data.tools is not None + and len(self.node_data.tools) > 0 + and all(tool.enabled for tool in self.node_data.tools) + ) + + def _invoke_llm_with_tools( + self, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + files: Sequence["File"], + variable_pool: VariablePool, + node_inputs: dict[str, Any], + process_data: dict[str, Any], + ) -> Generator[NodeEventBase, None, None]: + """Invoke LLM with tools support (from Agent V2).""" + # Get model features to determine strategy + model_features = self._get_model_features(model_instance) + + # Prepare tool instances + tool_instances = self._prepare_tool_instances(variable_pool) + + # Prepare prompt files (files that come from prompt variables, not vision files) + prompt_files = self._extract_prompt_files(variable_pool) + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=model_instance, + tools=tool_instances, + files=prompt_files, + max_iterations=10, + context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + ) + + # Run strategy + outputs = strategy.run( + prompt_messages=list(prompt_messages), + model_parameters=self._node_data.model.completion_params, + stop=list(stop or []), + stream=True, + ) + + # Process outputs + yield from self._process_tool_outputs(outputs, strategy, node_inputs, process_data) + + def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: + """Get model schema to determine features.""" + try: + model_type_instance = model_instance.model_type_instance + model_schema = model_type_instance.get_model_schema( + model_instance.model, + model_instance.credentials, + ) + return model_schema.features if model_schema and model_schema.features else [] + except Exception: + logger.warning("Failed to get model schema, assuming no special features") + return [] + + def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]: + """Prepare tool instances from configuration.""" + tool_instances = [] + + if self._node_data.tools: + for tool in self._node_data.tools: + try: + # Process settings to extract the correct structure + processed_settings 
= {} + for key, value in tool.settings.items(): + if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict): + # Extract the nested value if it has the ToolInput structure + if "type" in value["value"] and "value" in value["value"]: + processed_settings[key] = value["value"] + else: + processed_settings[key] = value + else: + processed_settings[key] = value + + # Merge parameters with processed settings (similar to Agent Node logic) + merged_parameters = {**tool.parameters, **processed_settings} + + # Create AgentToolEntity from ToolMetadata + agent_tool = AgentToolEntity( + provider_id=tool.provider_name, + provider_type=tool.type, + tool_name=tool.tool_name, + tool_parameters=merged_parameters, + plugin_unique_identifier=tool.plugin_unique_identifier, + credential_id=tool.credential_id, + ) + + # Get tool runtime from ToolManager + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + app_id=self.app_id, + agent_tool=agent_tool, + invoke_from=self.invoke_from, + variable_pool=variable_pool, + ) + + # Apply custom description from extra field if available + if tool.extra.get("description") and tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + tool.extra.get("description") or tool_runtime.entity.description.llm + ) + + tool_instances.append(tool_runtime) + except Exception as e: + logger.warning("Failed to load tool %s: %s", tool, str(e)) + continue + + return tool_instances + + def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]: + """Extract files from prompt template variables.""" + from core.variables import ArrayFileVariable, FileVariable + + files: list[File] = [] + + # Extract variables from prompt template + if isinstance(self._node_data.prompt_template, list): + for message in self._node_data.prompt_template: + if message.text: + parser = VariableTemplateParser(message.text) + variable_selectors = parser.extract_variable_selectors() + + for variable_selector in variable_selectors: + variable = variable_pool.get(variable_selector.value_selector) + if isinstance(variable, FileVariable) and variable.value: + files.append(variable.value) + elif isinstance(variable, ArrayFileVariable) and variable.value: + files.extend(variable.value) + + return files + + def _process_tool_outputs( + self, + outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult], + strategy: Any, + node_inputs: dict[str, Any], + process_data: dict[str, Any], + ) -> Generator[NodeEventBase, None, None]: + """Process strategy outputs and convert to node events.""" + text = "" + files: list[File] = [] + usage = LLMUsage.empty_usage() + agent_logs: list[AgentLogEvent] = [] + finish_reason = None + agent_result: AgentResult | None = None + + # Track current round for ThoughtChunkEvent + current_round = 1 + think_parser = llm_utils.ThinkTagStreamParser() + reasoning_chunks: list[str] = [] + + # Process each output from strategy + try: + for output in outputs: + if isinstance(output, AgentLog): + # Store agent log event for metadata (no longer yielded, StreamChunkEvent contains the info) + agent_log_event = AgentLogEvent( + message_id=output.id, + label=output.label, + node_execution_id=self.id, + parent_id=output.parent_id, + error=output.error, + status=output.status.value, + data=output.data, + metadata={k.value: v for k, v in output.metadata.items()}, + node_id=self._node_id, + ) + for log in agent_logs: + if log.message_id == agent_log_event.message_id: + # update the log + log.data = agent_log_event.data + 
log.status = agent_log_event.status + log.error = agent_log_event.error + log.label = agent_log_event.label + log.metadata = agent_log_event.metadata + break + else: + agent_logs.append(agent_log_event) + + # Extract round number from ROUND log type + if output.log_type == AgentLog.LogType.ROUND: + round_index = output.data.get("round_index") + if isinstance(round_index, int): + current_round = round_index + + # Emit tool call events when tool call starts + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: + tool_name = output.data.get("tool_name", "") + tool_call_id = output.data.get("tool_call_id", "") + tool_args = output.data.get("tool_args", {}) + tool_arguments = json.dumps(tool_args) if tool_args else "" + + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk=tool_arguments, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + is_final=True, + ) + + # Emit tool result events when tool call completes + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.SUCCESS: + tool_name = output.data.get("tool_name", "") + tool_output = output.data.get("output", "") + tool_call_id = output.data.get("tool_call_id", "") + tool_files = [] + tool_error = None + + # Extract file IDs if present + files_data = output.data.get("files") + if files_data and isinstance(files_data, list): + tool_files = files_data + + # Check for error in meta + meta = output.data.get("meta") + if meta and isinstance(meta, dict) and meta.get("error"): + tool_error = meta.get("error") + + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk=str(tool_output) if tool_output else "", + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_files=tool_files, + tool_error=tool_error, + is_final=True, + ) + + elif isinstance(output, LLMResultChunk): + # Handle LLM result chunks - only process text content + message = output.delta.message + + # Handle text content + if message and message.content: + chunk_text = message.content + if isinstance(chunk_text, list): + # Extract text from content list + chunk_text = "".join(getattr(c, "data", str(c)) for c in chunk_text) + else: + chunk_text = str(chunk_text) + for kind, segment in think_parser.process(chunk_text): + if not segment: + continue + + if kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + round_index=current_round, + is_final=False, + ) + else: + text += segment + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + if output.delta.usage: + self._accumulate_usage(usage, output.delta.usage) + + # Capture finish reason + if output.delta.finish_reason: + finish_reason = output.delta.finish_reason + + except StopIteration as e: + # Get the return value from generator + if isinstance(getattr(e, "value", None), AgentResult): + agent_result = e.value + + # Use result from generator if available + if agent_result: + text = agent_result.text or text + files = agent_result.files + if agent_result.usage: + usage = agent_result.usage + if agent_result.finish_reason: + finish_reason = agent_result.finish_reason + + # Flush any remaining buffered content after streaming ends + for kind, segment in think_parser.flush(): + if not 
segment: + continue + if kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + round_index=current_round, + is_final=False, + ) + else: + text += segment + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + # Send final events for all streams + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + + # Close generation sub-field streams + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + round_index=current_round, + is_final=True, + ) + + # Build generation field from agent_logs + tool_calls_for_generation = [] + for log in agent_logs: + if log.label == "Tool Call": + tool_call_data = { + "id": log.data.get("tool_call_id", ""), + "name": log.data.get("tool_name", ""), + "arguments": json.dumps(log.data.get("tool_args", {})), + "result": log.data.get("output", ""), + } + tool_calls_for_generation.append(tool_call_data) + + # Complete with results + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "text": text, + "files": ArrayFileSegment(value=files), + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, + "generation": { + "reasoning_content": ["".join(reasoning_chunks)] if reasoning_chunks else [], + "tool_calls": tool_calls_for_generation, + "content": text, + }, + }, + metadata={ + WorkflowNodeExecutionMetadataKey.LLM_CONTENT_SEQUENCE: [], + }, + inputs={ + **node_inputs, + "tools": [ + {"provider_id": tool.provider_name, "tool_name": tool.tool_name} + for tool in self._node_data.tools + ] + if self._node_data.tools + else [], + }, + process_data=process_data, + llm_usage=usage, + ) + ) + + def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None: + """Accumulate LLM usage statistics.""" + total_usage.prompt_tokens += delta_usage.prompt_tokens + total_usage.completion_tokens += delta_usage.completion_tokens + total_usage.total_tokens += delta_usage.total_tokens + total_usage.prompt_price += delta_usage.prompt_price + total_usage.completion_price += delta_usage.completion_price + total_usage.total_price += delta_usage.total_price + def _combine_message_content_with_role( *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index ecc267cf38..d5b2574edc 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -89,6 +89,7 @@ message_detail_fields = { "status": fields.String, "error": fields.String, "parent_message_id": fields.String, + "generation_detail": fields.Raw, } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index a419da2e18..8b9bcac76f 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -68,6 +68,7 @@ message_fields = { "message_files": fields.List(fields.Nested(message_file_fields)), "status": fields.String, "error": fields.String, + "generation_detail": fields.Raw, } message_infinite_scroll_pagination_fields = { diff --git 
a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 821ce62ecc..6305d8d9d5 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -129,6 +129,7 @@ workflow_run_node_execution_fields = { "inputs_truncated": fields.Boolean, "outputs_truncated": fields.Boolean, "process_data_truncated": fields.Boolean, + "generation_detail": fields.Raw, } workflow_run_node_execution_list_fields = { diff --git a/api/models/__init__.py b/api/models/__init__.py index 906bc3198e..bc29421d4c 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -49,6 +49,7 @@ from .model import ( EndUser, IconType, InstalledApp, + LLMGenerationDetail, Message, MessageAgentThought, MessageAnnotation, @@ -154,6 +155,7 @@ __all__ = [ "IconType", "InstalledApp", "InvitationCode", + "LLMGenerationDetail", "LoadBalancingModelConfig", "Message", "MessageAgentThought", diff --git a/api/models/model.py b/api/models/model.py index 1731ff5699..0e862ad845 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,6 +31,8 @@ from .provider_ids import GenericProviderID from .types import LongText, StringUUID if TYPE_CHECKING: + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + from .workflow import Workflow @@ -1169,6 +1171,17 @@ class Message(Base): .all() ) + @property + def generation_detail(self) -> dict[str, Any] | None: + """ + Get LLM generation detail for this message. + Returns the detail as a dictionary or None if not found. + """ + detail = db.session.query(LLMGenerationDetail).filter_by(message_id=self.id).first() + if detail: + return detail.to_dict() + return None + @property def retriever_resources(self) -> Any: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @@ -2041,3 +2054,87 @@ class TraceAppConfig(TypeBase): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class LLMGenerationDetail(Base): + """ + Store LLM generation details including reasoning process and tool calls. 
+ + Association (choose one): + - For apps with Message: use message_id (one-to-one) + - For Workflow: use workflow_run_id + node_id (one run may have multiple LLM nodes) + """ + + __tablename__ = "llm_generation_details" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"), + sa.Index("idx_llm_generation_detail_message", "message_id"), + sa.Index("idx_llm_generation_detail_workflow", "workflow_run_id", "node_id"), + sa.CheckConstraint( + "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL)" + " OR " + "(message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)", + name="ck_llm_generation_detail_assoc_mode", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Association fields (choose one) + message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, unique=True) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + node_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # Core data as JSON strings + reasoning_content: Mapped[str | None] = mapped_column(LongText) + tool_calls: Mapped[str | None] = mapped_column(LongText) + sequence: Mapped[str | None] = mapped_column(LongText) + + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + def to_domain_model(self) -> "LLMGenerationDetailData": + """Convert to Pydantic domain model with proper validation.""" + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + + return LLMGenerationDetailData( + reasoning_content=json.loads(self.reasoning_content) if self.reasoning_content else [], + tool_calls=json.loads(self.tool_calls) if self.tool_calls else [], + sequence=json.loads(self.sequence) if self.sequence else [], + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API response.""" + return self.to_domain_model().to_response_dict() + + @classmethod + def from_domain_model( + cls, + data: "LLMGenerationDetailData", + *, + tenant_id: str, + app_id: str, + message_id: str | None = None, + workflow_run_id: str | None = None, + node_id: str | None = None, + ) -> "LLMGenerationDetail": + """Create from Pydantic domain model.""" + # Enforce association mode at object creation time as well. 
+ message_mode = message_id is not None + workflow_mode = workflow_run_id is not None or node_id is not None + if message_mode and workflow_mode: + raise ValueError("LLMGenerationDetail cannot set both message_id and workflow_run_id/node_id.") + if not message_mode and not (workflow_run_id and node_id): + raise ValueError("LLMGenerationDetail requires either message_id or workflow_run_id+node_id.") + + return cls( + tenant_id=tenant_id, + app_id=app_id, + message_id=message_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + reasoning_content=json.dumps(data.reasoning_content) if data.reasoning_content else None, + tool_calls=json.dumps([tc.model_dump() for tc in data.tool_calls]) if data.tool_calls else None, + sequence=json.dumps([seg.model_dump() for seg in data.sequence]) if data.sequence else None, + ) diff --git a/api/services/llm_generation_service.py b/api/services/llm_generation_service.py new file mode 100644 index 0000000000..1e8c78a416 --- /dev/null +++ b/api/services/llm_generation_service.py @@ -0,0 +1,131 @@ +""" +LLM Generation Detail Service. + +Provides methods to query and attach generation details to workflow node executions +and messages, avoiding N+1 query problems. +""" + +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.llm_generation_entities import LLMGenerationDetailData +from models import LLMGenerationDetail, WorkflowNodeExecutionModel + + +class LLMGenerationService: + """Service for handling LLM generation details.""" + + def __init__(self, session: Session): + self._session = session + + def get_generation_details_for_workflow_run( + self, + workflow_run_id: str, + *, + tenant_id: str | None = None, + app_id: str | None = None, + ) -> dict[str, LLMGenerationDetailData]: + """ + Batch query generation details for all LLM nodes in a workflow run. + + Returns dict mapping node_id to LLMGenerationDetailData. + """ + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.workflow_run_id == workflow_run_id) + if tenant_id: + stmt = stmt.where(LLMGenerationDetail.tenant_id == tenant_id) + if app_id: + stmt = stmt.where(LLMGenerationDetail.app_id == app_id) + details = self._session.scalars(stmt).all() + return {detail.node_id: detail.to_domain_model() for detail in details if detail.node_id} + + def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None: + """Query generation detail for a specific message.""" + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id) + detail = self._session.scalars(stmt).first() + return detail.to_domain_model() if detail else None + + def get_generation_details_for_messages( + self, + message_ids: list[str], + ) -> dict[str, LLMGenerationDetailData]: + """Batch query generation details for multiple messages.""" + if not message_ids: + return {} + + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids)) + details = self._session.scalars(stmt).all() + return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id} + + def attach_generation_details_to_node_executions( + self, + node_executions: Sequence[WorkflowNodeExecutionModel], + workflow_run_id: str, + *, + tenant_id: str | None = None, + app_id: str | None = None, + ) -> list[dict]: + """ + Attach generation details to node executions and return as dicts. 
+ + Queries generation details in batch and attaches them to the corresponding + node executions, avoiding N+1 queries. + """ + generation_details = self.get_generation_details_for_workflow_run( + workflow_run_id, tenant_id=tenant_id, app_id=app_id + ) + + return [ + { + "id": node.id, + "index": node.index, + "predecessor_node_id": node.predecessor_node_id, + "node_id": node.node_id, + "node_type": node.node_type, + "title": node.title, + "inputs": node.inputs_dict, + "process_data": node.process_data_dict, + "outputs": node.outputs_dict, + "status": node.status, + "error": node.error, + "elapsed_time": node.elapsed_time, + "execution_metadata": node.execution_metadata_dict, + "extras": node.extras, + "created_at": int(node.created_at.timestamp()) if node.created_at else None, + "created_by_role": node.created_by_role, + "created_by_account": _serialize_account(node.created_by_account), + "created_by_end_user": _serialize_end_user(node.created_by_end_user), + "finished_at": int(node.finished_at.timestamp()) if node.finished_at else None, + "inputs_truncated": node.inputs_truncated, + "outputs_truncated": node.outputs_truncated, + "process_data_truncated": node.process_data_truncated, + "generation_detail": generation_details[node.node_id].to_response_dict() + if node.node_id in generation_details + else None, + } + for node in node_executions + ] + + +def _serialize_account(account) -> dict | None: + """Serialize Account to dict for API response.""" + if not account: + return None + return { + "id": account.id, + "name": account.name, + "email": account.email, + } + + +def _serialize_end_user(end_user) -> dict | None: + """Serialize EndUser to dict for API response.""" + if not end_user: + return None + return { + "id": end_user.id, + "type": end_user.type, + "is_anonymous": end_user.is_anonymous, + "session_id": end_user.session_id, + } diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index b903d8df5f..14bcca8754 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,8 +1,8 @@ import threading -from collections.abc import Sequence +from typing import Any from sqlalchemy import Engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker import contexts from extensions.ext_database import db @@ -11,12 +11,12 @@ from models import ( Account, App, EndUser, - WorkflowNodeExecutionModel, WorkflowRun, WorkflowRunTriggeredFrom, ) from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +from services.llm_generation_service import LLMGenerationService class WorkflowRunService: @@ -137,9 +137,9 @@ class WorkflowRunService: app_model: App, run_id: str, user: Account | EndUser, - ) -> Sequence[WorkflowNodeExecutionModel]: + ) -> list[dict[str, Any]]: """ - Get workflow run node execution list + Get workflow run node execution list with generation details attached. 
""" workflow_run = self.get_workflow_run(app_model, run_id) @@ -154,8 +154,18 @@ class WorkflowRunService: if tenant_id is None: raise ValueError("User tenant_id cannot be None") - return self._node_execution_service_repo.get_executions_by_workflow_run( + node_executions = self._node_execution_service_repo.get_executions_by_workflow_run( tenant_id=tenant_id, app_id=app_model.id, workflow_run_id=run_id, ) + + # Attach generation details using batch query + with Session(db.engine) as session: + generation_service = LLMGenerationService(session) + return generation_service.attach_generation_details_to_node_executions( + node_executions=node_executions, + workflow_run_id=run_id, + tenant_id=tenant_id, + app_id=app_model.id, + ) diff --git a/api/tests/unit_tests/core/agent/patterns/__init__.py b/api/tests/unit_tests/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py new file mode 100644 index 0000000000..b0e0d44940 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -0,0 +1,324 @@ +"""Tests for AgentPattern base class.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.agent.patterns.base import AgentPattern +from core.model_runtime.entities.llm_entities import LLMUsage + + +class ConcreteAgentPattern(AgentPattern): + """Concrete implementation of AgentPattern for testing.""" + + def run(self, prompt_messages, model_parameters, stop=[], stream=True): + """Minimal implementation for testing.""" + yield from [] + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def agent_pattern(mock_model_instance, mock_context): + """Create a concrete agent pattern for testing.""" + return ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=10, + ) + + +class TestAccumulateUsage: + """Tests for _accumulate_usage method.""" + + def test_accumulate_usage_to_empty_dict(self, agent_pattern): + """Test accumulating usage to an empty dict creates a copy.""" + total_usage: dict = {"usage": None} + delta_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"] is not None + assert total_usage["usage"].total_tokens == 150 + assert total_usage["usage"].prompt_tokens == 100 + assert total_usage["usage"].completion_tokens == 50 + # Verify it's a copy, not a reference + assert total_usage["usage"] is not delta_usage + + def test_accumulate_usage_adds_to_existing(self, agent_pattern): + """Test accumulating usage adds to existing 
values.""" + initial_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + total_usage: dict = {"usage": initial_usage} + + delta_usage = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"].total_tokens == 450 # 150 + 300 + assert total_usage["usage"].prompt_tokens == 300 # 100 + 200 + assert total_usage["usage"].completion_tokens == 150 # 50 + 100 + + def test_accumulate_usage_multiple_rounds(self, agent_pattern): + """Test accumulating usage across multiple rounds.""" + total_usage: dict = {"usage": None} + + # Round 1: 100 tokens + round1_usage = LLMUsage( + prompt_tokens=70, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.07"), + completion_tokens=30, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.06"), + total_tokens=100, + total_price=Decimal("0.13"), + currency="USD", + latency=0.3, + ) + agent_pattern._accumulate_usage(total_usage, round1_usage) + assert total_usage["usage"].total_tokens == 100 + + # Round 2: 150 tokens + round2_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.4, + ) + agent_pattern._accumulate_usage(total_usage, round2_usage) + assert total_usage["usage"].total_tokens == 250 # 100 + 150 + + # Round 3: 200 tokens + round3_usage = LLMUsage( + prompt_tokens=130, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.13"), + completion_tokens=70, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.14"), + total_tokens=200, + total_price=Decimal("0.27"), + currency="USD", + latency=0.5, + ) + agent_pattern._accumulate_usage(total_usage, round3_usage) + assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200 + + +class TestCreateLog: + """Tests for _create_log method.""" + + def test_create_log_with_label_and_status(self, agent_pattern): + """Test creating a log with label and status.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.parent_id is None + + def test_create_log_with_parent_id(self, agent_pattern): + """Test creating a log with parent_id.""" + parent_log = agent_pattern._create_log( + label="ROUND 1", + 
log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + child_log = agent_pattern._create_log( + label="CALL tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={}, + parent_id=parent_log.id, + ) + + assert child_log.parent_id == parent_log.id + assert child_log.log_type == AgentLog.LogType.TOOL_CALL + + +class TestFinishLog: + """Tests for _finish_log method.""" + + def test_finish_log_updates_status(self, agent_pattern): + """Test that finish_log updates status to SUCCESS.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + finished_log = agent_pattern._finish_log(log, data={"result": "done"}) + + assert finished_log.status == AgentLog.LogStatus.SUCCESS + assert finished_log.data == {"result": "done"} + + def test_finish_log_adds_usage_metadata(self, agent_pattern): + """Test that finish_log adds usage to metadata.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + finished_log = agent_pattern._finish_log(log, usage=usage) + + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150 + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2") + assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD" + assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage + + +class TestFindToolByName: + """Tests for _find_tool_by_name method.""" + + def test_find_existing_tool(self, mock_model_instance, mock_context): + """Test finding an existing tool by name.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("test_tool") + assert found_tool == mock_tool + + def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context): + """Test that finding a nonexistent tool returns None.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("nonexistent_tool") + assert found_tool is None + + +class TestMaxIterationsCapping: + """Tests for max_iterations capping.""" + + def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context): + """Test that max_iterations is capped at 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=150, + ) + + assert pattern.max_iterations == 99 + + def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context): + """Test that max_iterations is not capped when under 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=50, + ) + + assert pattern.max_iterations == 50 diff --git 
a/api/tests/unit_tests/core/agent/patterns/test_function_call.py b/api/tests/unit_tests/core/agent/patterns/test_function_call.py new file mode 100644 index 0000000000..6b3600dbbf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_function_call.py @@ -0,0 +1,332 @@ +"""Tests for FunctionCallStrategy.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import ( + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.to_prompt_message_tool.return_value = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={ + "type": "object", + "properties": {"param1": {"type": "string", "description": "A parameter"}}, + "required": ["param1"], + }, + ) + return tool + + +class TestFunctionCallStrategyInit: + """Tests for FunctionCallStrategy initialization.""" + + def test_initialization(self, mock_model_instance, mock_context, mock_tool): + """Test basic initialization.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=10, + ) + + assert strategy.model_instance == mock_model_instance + assert strategy.context == mock_context + assert strategy.max_iterations == 10 + assert len(strategy.tools) == 1 + + def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool): + """Test initialization with tool_invoke_hook.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + mock_hook = MagicMock() + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook + + +class TestConvertToolsToPromptFormat: + """Tests for _convert_tools_to_prompt_format method.""" + + def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool): + """Test that _convert_tools_to_prompt_format returns PromptMessageTool list.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert len(tools) == 1 + assert isinstance(tools[0], PromptMessageTool) + assert tools[0].name == "test_tool" + + def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context): + """Test that _convert_tools_to_prompt_format returns empty list when no tools.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + 
model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert tools == [] + + +class TestAgentLogGeneration: + """Tests for AgentLog generation during run.""" + + def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that round logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=1, + ) + + # Create a round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"inputs": {"query": "test"}}, + ) + + assert round_log.label == "ROUND 1" + assert round_log.log_type == AgentLog.LogType.ROUND + assert round_log.status == AgentLog.LogStatus.START + assert "inputs" in round_log.data + + def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that tool call logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + # Create a parent round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + # Create a tool call log + tool_log = strategy._create_log( + label="CALL test_tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}}, + parent_id=round_log.id, + ) + + assert tool_log.label == "CALL test_tool" + assert tool_log.log_type == AgentLog.LogType.TOOL_CALL + assert tool_log.parent_id == round_log.id + assert tool_log.data["tool_name"] == "test_tool" + + +class TestToolInvocation: + """Tests for tool invocation.""" + + def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool): + """Test that tool invocation uses hook when provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_hook = MagicMock() + mock_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={"tool_provider_type": "test", "tool_provider": "test_id"}, + ) + mock_hook.return_value = ("Tool result", ["file-1"], mock_meta) + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool") + + mock_hook.assert_called_once() + assert result == "Tool result" + assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list + assert meta == mock_meta + + def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool): + """Test that tool_invoke_hook is None when not provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=None, + ) + + # Verify that tool_invoke_hook is None + assert strategy.tool_invoke_hook is None + + +class TestUsageTracking: + """Tests for usage tracking across rounds.""" + + def test_round_usage_is_separate_from_total(self, 
mock_model_instance, mock_context): + """Test that round usage is tracked separately from total.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + # Simulate two rounds of usage + total_usage: dict = {"usage": None} + round1_usage: dict = {"usage": None} + round2_usage: dict = {"usage": None} + + # Round 1 + usage1 = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round1_usage, usage1) + strategy._accumulate_usage(total_usage, usage1) + + # Round 2 + usage2 = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round2_usage, usage2) + strategy._accumulate_usage(total_usage, usage2) + + # Verify round usage is separate + assert round1_usage["usage"].total_tokens == 150 + assert round2_usage["usage"].total_tokens == 300 + # Verify total is accumulated + assert total_usage["usage"].total_tokens == 450 + + +class TestPromptMessageHandling: + """Tests for prompt message handling.""" + + def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool): + """Test that messages include system and user prompts.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="You are a helpful assistant."), + UserPromptMessage(content="Hello"), + ] + + # Just verify the messages can be processed + assert len(messages) == 2 + assert isinstance(messages[0], SystemPromptMessage) + assert isinstance(messages[1], UserPromptMessage) + + def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool): + """Test that assistant messages can contain tool calls.""" + from core.model_runtime.entities.message_entities import AssistantPromptMessage + + tool_call = AssistantPromptMessage.ToolCall( + id="call_123", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="test_tool", + arguments='{"param1": "value1"}', + ), + ) + + assistant_message = AssistantPromptMessage( + content="I'll help you with that.", + tool_calls=[tool_call], + ) + + assert len(assistant_message.tool_calls) == 1 + assert assistant_message.tool_calls[0].function.name == "test_tool" diff --git a/api/tests/unit_tests/core/agent/patterns/test_react.py b/api/tests/unit_tests/core/agent/patterns/test_react.py new file mode 100644 index 0000000000..a942ba6100 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_react.py @@ -0,0 +1,224 @@ +"""Tests for ReActStrategy.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import ExecutionContext +from core.agent.patterns.react import ReActStrategy +from 
core.model_runtime.entities import SystemPromptMessage, UserPromptMessage + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + from core.model_runtime.entities.message_entities import PromptMessageTool + + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.entity.identity.provider = "test_provider" + + # Use real PromptMessageTool for proper serialization + prompt_tool = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + tool.to_prompt_message_tool.return_value = prompt_tool + + return tool + + +class TestReActStrategyInit: + """Tests for ReActStrategy initialization.""" + + def test_init_with_instruction(self, mock_model_instance, mock_context): + """Test that instruction is stored correctly.""" + instruction = "You are a helpful assistant." + + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + assert strategy.instruction == instruction + + def test_init_with_empty_instruction(self, mock_model_instance, mock_context): + """Test that empty instruction is handled correctly.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + assert strategy.instruction == "" + + +class TestBuildPromptWithReactFormat: + """Tests for _build_prompt_with_react_format method.""" + + def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tools}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + UserPromptMessage(content="Hello"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # The tools placeholder should be replaced with JSON + assert "{{tools}}" not in result[0].content + assert "test_tool" in result[0].content + + def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tool_names}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "Valid actions: {{tool_names}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + assert "{{tool_names}}" not in result[0].content + assert '"test_tool"' in result[0].content + + def test_replace_instruction_placeholder(self, mock_model_instance, mock_context): + """Test that {{instruction}} placeholder is replaced.""" + instruction = "You are a helpful coding assistant." 
+ strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + system_content = "{{instruction}}\n\nYou have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True, instruction) + + assert "{{instruction}}" not in result[0].content + assert instruction in result[0].content + + def test_no_tools_available_message(self, mock_model_instance, mock_context): + """Test that 'No tools available' is shown when include_tools is False.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], False) + + assert "No tools available" in result[0].content + + def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context): + """Test that agent scratchpad is appended as AssistantPromptMessage.""" + from core.agent.entities import AgentScratchpadUnit + from core.model_runtime.entities import AssistantPromptMessage + + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + scratchpad = [ + AgentScratchpadUnit( + thought="I need to search for information", + action_str='{"action": "search", "action_input": "query"}', + observation="Search results here", + ) + ] + + result = strategy._build_prompt_with_react_format(messages, scratchpad, True) + + # The last message should be an AssistantPromptMessage with scratchpad content + assert len(result) == 3 + assert isinstance(result[-1], AssistantPromptMessage) + assert "I need to search for information" in result[-1].content + assert "Search results here" in result[-1].content + + def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context): + """Test that empty scratchpad doesn't add extra message.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # Should only have the original 2 messages + assert len(result) == 2 + + def test_original_messages_not_modified(self, mock_model_instance, mock_context): + """Test that original messages list is not modified.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + original_content = "Original system prompt {{tools}}" + messages = [ + SystemPromptMessage(content=original_content), + ] + + strategy._build_prompt_with_react_format(messages, [], True) + + # Original message should not be modified + assert messages[0].content == original_content diff --git a/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py new file mode 100644 index 0000000000..07b9df2acf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py @@ -0,0 +1,203 @@ +"""Tests for StrategyFactory.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentEntity, ExecutionContext +from 
core.agent.patterns.function_call import FunctionCallStrategy +from core.agent.patterns.react import ReActStrategy +from core.agent.patterns.strategy_factory import StrategyFactory +from core.model_runtime.entities.model_entities import ModelFeature + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +class TestStrategyFactory: + """Tests for StrategyFactory.create_strategy method.""" + + def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports TOOL_CALL.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL.""" + model_features = [ModelFeature.MULTI_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL.""" + model_features = [ModelFeature.STREAM_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model doesn't support tool calling.""" + model_features = [ModelFeature.VISION] # Only vision, no tool calling + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model has no features.""" + model_features: list[ModelFeature] = [] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context): + """Test explicit FUNCTION_CALLING strategy selection with model support.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + 
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_explicit_function_calling_strategy_without_support_falls_back_to_react( + self, mock_model_instance, mock_context + ): + """Test that explicit FUNCTION_CALLING falls back to ReAct when not supported.""" + model_features: list[ModelFeature] = [] # No tool calling support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + # Should fall back to ReAct since FC is not supported + assert isinstance(strategy, ReActStrategy) + + def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context): + """Test explicit CHAIN_OF_THOUGHT strategy selection.""" + model_features = [ModelFeature.TOOL_CALL] # Even with tool call support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + ) + + assert isinstance(strategy, ReActStrategy) + + def test_react_strategy_with_instruction(self, mock_model_instance, mock_context): + """Test that ReActStrategy receives instruction parameter.""" + model_features: list[ModelFeature] = [] + instruction = "You are a helpful assistant." + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + instruction=instruction, + ) + + assert isinstance(strategy, ReActStrategy) + assert strategy.instruction == instruction + + def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that max_iterations is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + max_iterations = 5 + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + max_iterations=max_iterations, + ) + + assert strategy.max_iterations == max_iterations + + def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that tool_invoke_hook is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + mock_hook = MagicMock() + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py new file mode 100644 index 0000000000..d9301ccfe0 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -0,0 +1,388 @@ +"""Tests for AgentAppRunner.""" + +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult +from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.llm_entities import LLMUsage + + +class TestOrganizePromptMessages: + """Tests for _organize_prompt_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + # We'll patch the class to avoid 
complex initialization
+        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
+            from core.agent.agent_app_runner import AgentAppRunner
+
+            runner = AgentAppRunner.__new__(AgentAppRunner)
+
+            # Set up required attributes
+            runner.config = MagicMock(spec=AgentEntity)
+            runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+            runner.config.prompt = None
+
+            runner.app_config = MagicMock()
+            runner.app_config.prompt_template = MagicMock()
+            runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
+
+            runner.history_prompt_messages = []
+            runner.query = "Hello"
+            runner._current_thoughts = []
+            runner.files = []
+            runner.model_config = MagicMock()
+            runner.memory = None
+            runner.application_generate_entity = MagicMock()
+            runner.application_generate_entity.file_upload_config = None
+
+            return runner
+
+    def test_function_calling_uses_simple_prompt(self, mock_runner):
+        """Test that function calling strategy uses simple_prompt_template."""
+        mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
+
+        with patch.object(mock_runner, "_init_system_message") as mock_init:
+            mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
+            with patch.object(mock_runner, "_organize_user_query") as mock_query:
+                mock_query.return_value = [UserPromptMessage(content="Hello")]
+                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
+                    mock_transform.return_value.get_prompt.return_value = [
+                        SystemPromptMessage(content="You are a helpful assistant.")
+                    ]
+
+                    # Call directly; the assertions below inspect the mocked _init_system_message
+                    mock_runner._organize_prompt_messages()
+
+                    # Verify _init_system_message was called with simple_prompt_template
+                    mock_init.assert_called_once()
+                    call_args = mock_init.call_args[0]
+                    assert call_args[0] == "You are a helpful assistant."
+
+    def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
+        """Test that chain of thought strategy uses agent prompt template."""
+        mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+        mock_runner.config.prompt = AgentPromptEntity(
+            first_prompt="ReAct prompt template with {{tools}}",
+            next_iteration="Continue...",
+        )
+
+        with patch.object(mock_runner, "_init_system_message") as mock_init:
+            mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
+            with patch.object(mock_runner, "_organize_user_query") as mock_query:
+                mock_query.return_value = [UserPromptMessage(content="Hello")]
+                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
+                    mock_transform.return_value.get_prompt.return_value = [
+                        SystemPromptMessage(content="ReAct prompt template with {{tools}}")
+                    ]
+
+                    mock_runner._organize_prompt_messages()
+
+                    # Verify _init_system_message was called with agent prompt
+                    mock_init.assert_called_once()
+                    call_args = mock_init.call_args[0]
+                    assert call_args[0] == "ReAct prompt template with {{tools}}"
+
+    def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
+        """Test that chain of thought without prompt falls back to simple_prompt_template."""
+        mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
+        mock_runner.config.prompt = None
+
+        with patch.object(mock_runner, "_init_system_message") as mock_init:
+            mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
+            with patch.object(mock_runner, "_organize_user_query") as mock_query:
+                mock_query.return_value = [UserPromptMessage(content="Hello")]
+                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
+                    mock_transform.return_value.get_prompt.return_value = [
+                        SystemPromptMessage(content="You are a helpful assistant.")
+                    ]
+
+                    mock_runner._organize_prompt_messages()
+
+                    # Verify _init_system_message was called with simple_prompt_template
+                    mock_init.assert_called_once()
+                    call_args = mock_init.call_args[0]
+                    assert call_args[0] == "You are a helpful assistant."
+ + +class TestInitSystemMessage: + """Tests for _init_system_message method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_empty_messages_with_template(self, mock_runner): + """Test that system message is created when messages are empty.""" + result = mock_runner._init_system_message("System template", []) + + assert len(result) == 1 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + def test_empty_messages_without_template(self, mock_runner): + """Test that empty list is returned when no template and no messages.""" + result = mock_runner._init_system_message("", []) + + assert result == [] + + def test_existing_system_message_not_duplicated(self, mock_runner): + """Test that system message is not duplicated if already present.""" + existing_messages = [ + SystemPromptMessage(content="Existing system"), + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("New template", existing_messages) + + # Should not insert new system message + assert len(result) == 2 + assert result[0].content == "Existing system" + + def test_system_message_inserted_when_missing(self, mock_runner): + """Test that system message is inserted when first message is not system.""" + existing_messages = [ + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("System template", existing_messages) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + +class TestClearUserPromptImageMessages: + """Tests for _clear_user_prompt_image_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_text_content_unchanged(self, mock_runner): + """Test that text content is unchanged.""" + messages = [ + UserPromptMessage(content="Plain text message"), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + assert len(result) == 1 + assert result[0].content == "Plain text message" + + def test_original_messages_not_modified(self, mock_runner): + """Test that original messages are not modified (deep copy).""" + from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Text part"), + ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ), + ] + ), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + # Original should still have list content + assert isinstance(messages[0].content, list) + # Result should have string content + assert isinstance(result[0].content, str) + + +class TestToolInvokeHook: + """Tests for _create_tool_invoke_hook method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + 
from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + + runner.user_id = "test-user" + runner.tenant_id = "test-tenant" + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.trace_manager = None + runner.application_generate_entity.invoke_from = "api" + runner.application_generate_entity.app_config = MagicMock() + runner.application_generate_entity.app_config.app_id = "test-app" + runner.agent_callback = MagicMock() + runner.conversation = MagicMock() + runner.conversation.id = "test-conversation" + runner.queue_manager = MagicMock() + runner._current_message_file_ids = [] + + return runner + + def test_hook_calls_agent_invoke(self, mock_runner): + """Test that the hook calls ToolEngine.agent_invoke.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={ + "tool_provider_type": "test_provider", + "tool_provider": "test_id", + }, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool") + + # Verify ToolEngine.agent_invoke was called + mock_engine.agent_invoke.assert_called_once() + + # Verify return values + assert result_content == "Tool result" + assert result_files == ["file-1", "file-2"] + assert result_meta == mock_tool_meta + + def test_hook_publishes_file_events(self, mock_runner): + """Test that the hook publishes QueueMessageFileEvent for files.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={}, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + hook(mock_tool, {}, "test_tool") + + # Verify file events were published + assert mock_runner.queue_manager.publish.call_count == 2 + assert mock_runner._current_message_file_ids == ["file-1", "file-2"] + + +class TestAgentLogProcessing: + """Tests for AgentLog processing in run method.""" + + def test_agent_log_status_enum(self): + """Test AgentLog status enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_agent_log_metadata_enum(self): + """Test AgentLog metadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + def test_agent_result_structure(self): + """Test AgentResult structure.""" + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + 
completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + result = AgentResult( + text="Final answer", + files=[], + usage=usage, + finish_reason="stop", + ) + + assert result.text == "Final answer" + assert result.files == [] + assert result.usage == usage + assert result.finish_reason == "stop" + + +class TestOrganizeUserQuery: + """Tests for _organize_user_query method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + runner.files = [] + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.file_upload_config = None + return runner + + def test_simple_query_without_files(self, mock_runner): + """Test organizing a simple query without files.""" + result = mock_runner._organize_user_query("Hello world", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == "Hello world" + + def test_query_with_files(self, mock_runner): + """Test organizing a query with files.""" + from core.file.models import File + + mock_file = MagicMock(spec=File) + mock_runner.files = [mock_file] + + with patch("core.agent.agent_app_runner.file_manager") as mock_fm: + from core.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ) + + result = mock_runner._organize_user_query("Describe this image", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert isinstance(result[0].content, list) + assert len(result[0].content) == 2 # Image + Text diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py new file mode 100644 index 0000000000..5136f48aab --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_entities.py @@ -0,0 +1,191 @@ +"""Tests for agent entities.""" + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext + + +class TestExecutionContext: + """Tests for ExecutionContext entity.""" + + def test_create_with_all_fields(self): + """Test creating ExecutionContext with all fields.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + assert context.user_id == "user-123" + assert context.app_id == "app-456" + assert context.conversation_id == "conv-789" + assert context.message_id == "msg-012" + assert context.tenant_id == "tenant-345" + + def test_create_minimal(self): + """Test creating minimal ExecutionContext.""" + context = ExecutionContext.create_minimal(user_id="user-123") + + assert context.user_id == "user-123" + assert context.app_id is None + assert context.conversation_id is None + assert context.message_id is None + assert context.tenant_id is None + + def test_to_dict(self): + """Test converting ExecutionContext to dictionary.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + result = context.to_dict() + + assert 
result == { + "user_id": "user-123", + "app_id": "app-456", + "conversation_id": "conv-789", + "message_id": "msg-012", + "tenant_id": "tenant-345", + } + + def test_with_updates(self): + """Test creating new context with updates.""" + original = ExecutionContext( + user_id="user-123", + app_id="app-456", + ) + + updated = original.with_updates(message_id="msg-789") + + # Original should be unchanged + assert original.message_id is None + # Updated should have new value + assert updated.message_id == "msg-789" + assert updated.user_id == "user-123" + assert updated.app_id == "app-456" + + +class TestAgentLog: + """Tests for AgentLog entity.""" + + def test_create_log_with_required_fields(self): + """Test creating AgentLog with required fields.""" + log = AgentLog( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.id is not None # Auto-generated + assert log.parent_id is None + assert log.error is None + + def test_log_type_enum(self): + """Test LogType enum values.""" + assert AgentLog.LogType.ROUND == "round" + assert AgentLog.LogType.THOUGHT == "thought" + assert AgentLog.LogType.TOOL_CALL == "tool_call" + + def test_log_status_enum(self): + """Test LogStatus enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_log_metadata_enum(self): + """Test LogMetadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + +class TestAgentScratchpadUnit: + """Tests for AgentScratchpadUnit entity.""" + + def test_is_final_with_final_answer_action(self): + """Test is_final returns True for Final Answer action.""" + unit = AgentScratchpadUnit( + thought="I know the answer", + action=AgentScratchpadUnit.Action( + action_name="Final Answer", + action_input="The answer is 42", + ), + ) + + assert unit.is_final() is True + + def test_is_final_with_tool_action(self): + """Test is_final returns False for tool action.""" + unit = AgentScratchpadUnit( + thought="I need to search", + action=AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ), + ) + + assert unit.is_final() is False + + def test_is_final_with_no_action(self): + """Test is_final returns True when no action.""" + unit = AgentScratchpadUnit( + thought="Just thinking", + ) + + assert unit.is_final() is True + + def test_action_to_dict(self): + """Test Action.to_dict method.""" + action = AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ) + + result = action.to_dict() + + assert result == { + "action": "search", + "action_input": {"query": "test"}, + } + + +class TestAgentEntity: + """Tests for AgentEntity.""" + + def test_strategy_enum(self): + """Test Strategy enum values.""" + assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought" + assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling" + + def test_create_with_prompt(self): + """Test creating AgentEntity with prompt.""" + prompt 
= AgentPromptEntity( + first_prompt="You are a helpful assistant.", + next_iteration="Continue thinking...", + ) + + entity = AgentEntity( + provider="openai", + model="gpt-4", + strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + prompt=prompt, + max_iteration=5, + ) + + assert entity.provider == "openai" + assert entity.model == "gpt-4" + assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT + assert entity.prompt == prompt + assert entity.max_iteration == 5 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py new file mode 100644 index 0000000000..388496ce1d --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -0,0 +1,169 @@ +"""Tests for ResponseStreamCoordinator object field streaming.""" + +from unittest.mock import MagicMock + +from core.workflow.enums import NodeType +from core.workflow.graph import Graph +from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.runtime import VariablePool + + +class TestResponseCoordinatorObjectStreaming: + """Test streaming of object-type variables with child fields.""" + + def test_object_field_streaming(self): + """Test that when selecting an object variable, all child field streams are forwarded.""" + # Create mock graph and variable pool + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + # Mock nodes + llm_node = MagicMock() + llm_node.id = "llm_node" + llm_node.node_type = NodeType.LLM + llm_node.execution_type = MagicMock() + llm_node.blocks_variable_output = MagicMock(return_value=False) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + response_node.execution_type = MagicMock() + response_node.blocks_variable_output = MagicMock(return_value=False) + + # Mock template for response node + response_node.node_data = MagicMock(spec=BaseNodeData) + response_node.node_data.answer = "{{#llm_node.generation#}}" + + graph.nodes = { + "llm_node": llm_node, + "response_node": response_node, + } + graph.root_node = llm_node + graph.get_outgoing_edges = MagicMock(return_value=[]) + + # Create coordinator + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Track execution + coordinator.track_node_execution("llm_node", "exec_123") + coordinator.track_node_execution("response_node", "exec_456") + + # Simulate streaming events for child fields of generation object + # 1. Content stream + content_event_1 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk="Hello", + is_final=False, + chunk_type=ChunkType.TEXT, + ) + content_event_2 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk=" world", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + # 2. 
Tool call stream + tool_call_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_calls"], + chunk='{"query": "test"}', + is_final=True, + chunk_type=ChunkType.TOOL_CALL, + tool_call_id="call_123", + tool_name="search", + tool_arguments='{"query": "test"}', + ) + + # 3. Tool result stream + tool_result_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_results"], + chunk="Found 10 results", + is_final=True, + chunk_type=ChunkType.TOOL_RESULT, + tool_call_id="call_123", + tool_name="search", + tool_files=[], + tool_error=None, + ) + + # Intercept these events + coordinator.intercept_event(content_event_1) + coordinator.intercept_event(tool_call_event) + coordinator.intercept_event(tool_result_event) + coordinator.intercept_event(content_event_2) + + # Verify that all child streams are buffered + assert ("llm_node", "generation", "content") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers + + # Verify we can find child streams + child_streams = coordinator._find_child_streams(["llm_node", "generation"]) + assert len(child_streams) == 3 + assert ("llm_node", "generation", "content") in child_streams + assert ("llm_node", "generation", "tool_calls") in child_streams + assert ("llm_node", "generation", "tool_results") in child_streams + + def test_find_child_streams(self): + """Test the _find_child_streams method.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some mock streams + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + ("node1", "generation", "tool_calls"): [], + ("node1", "generation", "thought"): [], + ("node1", "text"): [], # Not a child of generation + ("node2", "generation", "content"): [], # Different node + } + + # Find children of node1.generation + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children + assert ("node1", "text") not in children + assert ("node2", "generation", "content") not in children + + def test_find_child_streams_with_closed_streams(self): + """Test that _find_child_streams also considers closed streams.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some streams - some buffered, some closed + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + } + coordinator._closed_streams = { + ("node1", "generation", "tool_calls"), + ("node1", "generation", "thought"), + } + + # Should find all children regardless of whether they're in buffers or closed + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children diff --git a/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py 
b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py new file mode 100644 index 0000000000..498d43905e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py @@ -0,0 +1,336 @@ +"""Tests for StreamChunkEvent and its subclasses.""" + +from core.workflow.node_events import ( + ChunkType, + StreamChunkEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, +) + + +class TestChunkType: + """Tests for ChunkType enum.""" + + def test_chunk_type_values(self): + """Test that ChunkType has expected values.""" + assert ChunkType.TEXT == "text" + assert ChunkType.TOOL_CALL == "tool_call" + assert ChunkType.TOOL_RESULT == "tool_result" + assert ChunkType.THOUGHT == "thought" + + def test_chunk_type_is_str_enum(self): + """Test that ChunkType values are strings.""" + for chunk_type in ChunkType: + assert isinstance(chunk_type.value, str) + + +class TestStreamChunkEvent: + """Tests for base StreamChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating StreamChunkEvent with required fields.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="Hello", + ) + + assert event.selector == ["node1", "text"] + assert event.chunk == "Hello" + assert event.is_final is False + assert event.chunk_type == ChunkType.TEXT + + def test_create_with_all_fields(self): + """Test creating StreamChunkEvent with all fields.""" + event = StreamChunkEvent( + selector=["node1", "output"], + chunk="World", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + assert event.selector == ["node1", "output"] + assert event.chunk == "World" + assert event.is_final is True + assert event.chunk_type == ChunkType.TEXT + + def test_default_chunk_type_is_text(self): + """Test that default chunk_type is TEXT.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="test", + ) + + assert event.chunk_type == ChunkType.TEXT + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="Hello", + is_final=True, + ) + + data = event.model_dump() + + assert data["selector"] == ["node1", "text"] + assert data["chunk"] == "Hello" + assert data["is_final"] is True + assert data["chunk_type"] == "text" + + +class TestToolCallChunkEvent: + """Tests for ToolCallChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolCallChunkEvent with required fields.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call_id="call_123", + tool_name="weather", + ) + + assert event.selector == ["node1", "tool_calls"] + assert event.chunk == '{"city": "Beijing"}' + assert event.tool_call_id == "call_123" + assert event.tool_name == "weather" + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_chunk_type_is_tool_call(self): + """Test that chunk_type is always TOOL_CALL.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk="", + tool_call_id="call_123", + tool_name="test_tool", + ) + + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_tool_arguments_field(self): + """Test tool_arguments field.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"param": "value"}', + tool_call_id="call_123", + tool_name="test_tool", + tool_arguments='{"param": "value"}', + ) + + assert event.tool_arguments == '{"param": "value"}' + + def test_serialization(self): + """Test that event can be serialized to dict.""" + 
event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call_id="call_123", + tool_name="weather", + tool_arguments='{"city": "Beijing"}', + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_call" + assert data["tool_call_id"] == "call_123" + assert data["tool_name"] == "weather" + assert data["tool_arguments"] == '{"city": "Beijing"}' + assert data["is_final"] is True + + +class TestToolResultChunkEvent: + """Tests for ToolResultChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolResultChunkEvent with required fields.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="Weather: Sunny, 25°C", + tool_call_id="call_123", + tool_name="weather", + ) + + assert event.selector == ["node1", "tool_results"] + assert event.chunk == "Weather: Sunny, 25°C" + assert event.tool_call_id == "call_123" + assert event.tool_name == "weather" + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_chunk_type_is_tool_result(self): + """Test that chunk_type is always TOOL_RESULT.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_call_id="call_123", + tool_name="test_tool", + ) + + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_tool_files_default_empty(self): + """Test that tool_files defaults to empty list.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_call_id="call_123", + tool_name="test_tool", + ) + + assert event.tool_files == [] + + def test_tool_files_with_values(self): + """Test tool_files with file IDs.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_call_id="call_123", + tool_name="test_tool", + tool_files=["file_1", "file_2"], + ) + + assert event.tool_files == ["file_1", "file_2"] + + def test_tool_error_field(self): + """Test tool_error field.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="", + tool_call_id="call_123", + tool_name="test_tool", + tool_error="Tool execution failed", + ) + + assert event.tool_error == "Tool execution failed" + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="Weather: Sunny", + tool_call_id="call_123", + tool_name="weather", + tool_files=["file_1"], + tool_error=None, + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_result" + assert data["tool_call_id"] == "call_123" + assert data["tool_name"] == "weather" + assert data["tool_files"] == ["file_1"] + assert data["tool_error"] is None + assert data["is_final"] is True + + +class TestThoughtChunkEvent: + """Tests for ThoughtChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ThoughtChunkEvent with required fields.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to query the weather...", + ) + + assert event.selector == ["node1", "thought"] + assert event.chunk == "I need to query the weather..." 
+ assert event.chunk_type == ChunkType.THOUGHT + assert event.round_index == 1 # default + + def test_chunk_type_is_thought(self): + """Test that chunk_type is always THOUGHT.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert event.chunk_type == ChunkType.THOUGHT + + def test_round_index_default(self): + """Test that round_index defaults to 1.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert event.round_index == 1 + + def test_round_index_custom(self): + """Test custom round_index.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="Second round thinking...", + round_index=2, + ) + + assert event.round_index == 2 + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to analyze this...", + round_index=3, + is_final=False, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "thought" + assert data["round_index"] == 3 + assert data["chunk"] == "I need to analyze this..." + assert data["is_final"] is False + + +class TestEventInheritance: + """Tests for event inheritance relationships.""" + + def test_tool_call_is_stream_chunk(self): + """Test that ToolCallChunkEvent is a StreamChunkEvent.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk="", + tool_call_id="call_123", + tool_name="test", + ) + + assert isinstance(event, StreamChunkEvent) + + def test_tool_result_is_stream_chunk(self): + """Test that ToolResultChunkEvent is a StreamChunkEvent.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_call_id="call_123", + tool_name="test", + ) + + assert isinstance(event, StreamChunkEvent) + + def test_thought_is_stream_chunk(self): + """Test that ThoughtChunkEvent is a StreamChunkEvent.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert isinstance(event, StreamChunkEvent) + + def test_all_events_have_common_fields(self): + """Test that all events have common StreamChunkEvent fields.""" + events = [ + StreamChunkEvent(selector=["n", "t"], chunk="a"), + ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"), + ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"), + ThoughtChunkEvent(selector=["n", "t"], chunk="d"), + ] + + for event in events: + assert hasattr(event, "selector") + assert hasattr(event, "chunk") + assert hasattr(event, "is_final") + assert hasattr(event, "chunk_type") diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index ad498ff65b..206f39312a 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -131,6 +131,10 @@ export const LLM_OUTPUT_STRUCT: Var[] = [ variable: 'usage', type: VarType.object, }, + { + variable: 'generation', + type: VarType.object, + }, ] export const KNOWLEDGE_RETRIEVAL_OUTPUT_STRUCT: Var[] = [ diff --git a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx index 8e6993a78d..6cc00d91ee 100644 --- a/web/app/components/workflow/nodes/agent/components/tool-icon.tsx +++ b/web/app/components/workflow/nodes/agent/components/tool-icon.tsx @@ -29,9 +29,9 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => { }) }, [buildInTools, customTools, 
providerName, workflowTools, mcpTools]) - const providerNameParts = providerName.split('/') - const author = providerNameParts[0] - const name = providerNameParts[1] + const providerNameParts = providerName ? providerName.split('/') : [] + const author = providerNameParts[0] || '' + const name = providerNameParts[1] || providerName || '' const icon = useMemo(() => { if (!isDataReady) return '' if (currentProvider) return currentProvider.icon diff --git a/web/app/components/workflow/nodes/llm/components/tools-config.tsx b/web/app/components/workflow/nodes/llm/components/tools-config.tsx new file mode 100644 index 0000000000..147ff8cd49 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/tools-config.tsx @@ -0,0 +1,58 @@ +import type { FC } from 'react' +import { memo } from 'react' +import { useTranslation } from 'react-i18next' +import MultipleToolSelector from '@/app/components/plugins/plugin-detail-panel/multiple-tool-selector' +import type { NodeOutPutVar } from '@/app/components/workflow/types' +import type { ToolValue } from '@/app/components/workflow/block-selector/types' +import type { Node } from 'reactflow' +import Field from '@/app/components/workflow/nodes/_base/components/field' +import { RiHammerLine } from '@remixicon/react' + +type Props = { + tools?: ToolValue[] + onChange: (tools: ToolValue[]) => void + readonly?: boolean + nodeId?: string + availableVars?: NodeOutPutVar[] + availableNodes?: Node[] +} + +const ToolsConfig: FC = ({ + tools = [], + onChange, + readonly = false, + nodeId = '', + availableVars = [], + availableNodes = [], +}) => { + const { t } = useTranslation() + + return ( + + + {t('workflow.nodes.llm.tools')} + + } + operations={ +
+ {t('workflow.nodes.llm.toolsCount', { count: tools.length })} +
+ } + > + +
+ ) +} + +export default memo(ToolsConfig) diff --git a/web/app/components/workflow/nodes/llm/constants.ts b/web/app/components/workflow/nodes/llm/constants.ts new file mode 100644 index 0000000000..e733ca72c6 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/constants.ts @@ -0,0 +1,41 @@ +// ReAct prompt template for models that don't support tool_call or stream_tool_call +export const REACT_PROMPT_TEMPLATE = `Respond to the human as helpfully and accurately as possible. + +{{instruction}} + +You have access to the following tools: + +{{tools}} + +Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). +Valid "action" values: "Final Answer" or {{tool_names}} + +Provide only ONE action per $JSON_BLOB, as shown: + +\`\`\` +{ + "action": $TOOL_NAME, + "action_input": $ACTION_INPUT +} +\`\`\` + +Follow this format: + +Question: input question to answer +Thought: consider previous and subsequent steps +Action: +\`\`\` +$JSON_BLOB +\`\`\` +Observation: action result +... (repeat Thought/Action/Observation N times) +Thought: I know what to respond +Action: +\`\`\` +{ + "action": "Final Answer", + "action_input": "Final response to human" +} +\`\`\` + +Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation:.` diff --git a/web/app/components/workflow/nodes/llm/default.ts b/web/app/components/workflow/nodes/llm/default.ts index 57033d26a1..e225837d12 100644 --- a/web/app/components/workflow/nodes/llm/default.ts +++ b/web/app/components/workflow/nodes/llm/default.ts @@ -53,6 +53,7 @@ const nodeDefault: NodeDefault = { vision: { enabled: false, }, + tools: [], }, defaultRunInputData: { '#context#': [RETRIEVAL_OUTPUT_STRUCT], diff --git a/web/app/components/workflow/nodes/llm/node.tsx b/web/app/components/workflow/nodes/llm/node.tsx index ce676ba984..0f0b5bf390 100644 --- a/web/app/components/workflow/nodes/llm/node.tsx +++ b/web/app/components/workflow/nodes/llm/node.tsx @@ -1,26 +1,46 @@ import type { FC } from 'react' -import React from 'react' +import React, { useMemo } from 'react' import type { LLMNodeType } from './types' import { useTextGenerationCurrentProviderAndModelAndModelList, } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import type { NodeProps } from '@/app/components/workflow/types' +import { Group, GroupLabel } from '../_base/components/group' +import { ToolIcon } from '../agent/components/tool-icon' +import useConfig from './use-config' +import { useTranslation } from 'react-i18next' const Node: FC> = ({ + id, data, }) => { + const { t } = useTranslation() + const { inputs } = useConfig(id, data) const { provider, name: modelId } = data.model || {} const { textGenerationModelList, } = useTextGenerationCurrentProviderAndModelAndModelList() const hasSetModel = provider && modelId + // Extract tools information + const tools = useMemo(() => { + if (!inputs.tools || inputs.tools.length === 0) + return [] + + // For LLM Node, tools is ToolValue[] + // Each tool has provider_name which is the unique identifier + return inputs.tools.map((tool, index) => ({ + id: `tool-${index}`, + providerName: tool.provider_name, + })) + }, [inputs.tools]) + if (!hasSetModel) return null return ( -
+
{hasSetModel && ( > = ({ readonly /> )} + + {/* Tools display */} + {tools.length > 0 && ( + + {t('workflow.nodes.llm.tools')} + + } + > +
+            {tools.map(tool => (
+              <ToolIcon key={tool.id} providerName={tool.providerName} />
+            ))}
+          </div>
+        </Group>
+      )}
) } diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index bb893b0da7..b39a4ea373 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -18,6 +18,7 @@ import Tooltip from '@/app/components/base/tooltip' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' import StructureOutput from './components/structure-output' import ReasoningFormatConfig from './components/reasoning-format-config' +import ToolsConfig from './components/tools-config' import Switch from '@/app/components/base/switch' import { RiAlertFill, RiQuestionLine } from '@remixicon/react' import { fetchAndMergeValidCompletionParams } from '@/utils/completion-params' @@ -57,12 +58,14 @@ const Panel: FC> = ({ handleVisionResolutionEnabledChange, handleVisionResolutionChange, isModelSupportStructuredOutput, + isModelSupportToolCall, structuredOutputCollapsed, setStructuredOutputCollapsed, handleStructureOutputEnableChange, handleStructureOutputChange, filterJinja2InputVar, handleReasoningFormatChange, + handleToolsChange, } = useConfig(id, data) const model = inputs.model @@ -241,6 +244,26 @@ const Panel: FC> = ({ onConfigChange={handleVisionResolutionChange} /> + {/* Tools configuration */} + + + {/* Show warning when model doesn't support tool call but tools are selected */} + {inputs.tools && inputs.tools.length > 0 && !isModelSupportToolCall && isChatModel && ( +
+              <RiAlertFill />
+              <div>
+                {t('workflow.nodes.llm.toolsNotSupportedWarning')}
+              </div>
+            </div>
+ )} + {/* Reasoning Format */} > = ({ type='object' description={t(`${i18nPrefix}.outputVars.usage`)} /> + {inputs.structured_output_enabled && ( <> diff --git a/web/app/components/workflow/nodes/llm/types.ts b/web/app/components/workflow/nodes/llm/types.ts index 70dc4d9cc7..6bc3508cd0 100644 --- a/web/app/components/workflow/nodes/llm/types.ts +++ b/web/app/components/workflow/nodes/llm/types.ts @@ -1,4 +1,5 @@ import type { CommonNodeType, Memory, ModelConfig, PromptItem, ValueSelector, Variable, VisionSetting } from '@/app/components/workflow/types' +import type { ToolValue } from '@/app/components/workflow/block-selector/types' export type LLMNodeType = CommonNodeType & { model: ModelConfig @@ -18,6 +19,7 @@ export type LLMNodeType = CommonNodeType & { structured_output_enabled?: boolean structured_output?: StructuredOutput reasoning_format?: 'tagged' | 'separated' + tools?: ToolValue[] } export enum Type { diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index d9b811bb85..55bc65a4d4 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -1,7 +1,8 @@ import { useCallback, useEffect, useRef, useState } from 'react' +import { EditionType, PromptRole, VarType } from '../../types' import { produce } from 'immer' -import { EditionType, VarType } from '../../types' import type { Memory, PromptItem, ValueSelector, Var, Variable } from '../../types' +import type { ToolValue } from '../../block-selector/types' import { useStore } from '../../store' import { useIsChatMode, @@ -18,6 +19,7 @@ import { import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import { checkHasContextBlock, checkHasHistoryBlock, checkHasQueryBlock } from '@/app/components/base/prompt-editor/constants' import useInspectVarsCrud from '@/app/components/workflow/hooks/use-inspect-vars-crud' +import { REACT_PROMPT_TEMPLATE } from './constants' import { AppModeEnum } from '@/types/app' const useConfig = (id: string, payload: LLMNodeType) => { @@ -250,7 +252,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, [setInputs]) const handlePromptChange = useCallback((newPrompt: PromptItem[] | PromptItem) => { - const newInputs = produce(inputRef.current, (draft) => { + const newInputs = produce(inputs, (draft) => { draft.prompt_template = newPrompt }) setInputs(newInputs) @@ -283,10 +285,13 @@ const useConfig = (id: string, payload: LLMNodeType) => { // structure output const { data: modelList } = useModelList(ModelTypeEnum.textGeneration) - const isModelSupportStructuredOutput = modelList + const currentModelFeatures = modelList ?.find(provideItem => provideItem.provider === model?.provider) ?.models.find(modelItem => modelItem.model === model?.name) - ?.features?.includes(ModelFeatureEnum.StructuredOutput) + ?.features || [] + + const isModelSupportStructuredOutput = currentModelFeatures.includes(ModelFeatureEnum.StructuredOutput) + const isModelSupportToolCall = currentModelFeatures.includes(ModelFeatureEnum.toolCall) || currentModelFeatures.includes(ModelFeatureEnum.streamToolCall) const [structuredOutputCollapsed, setStructuredOutputCollapsed] = useState(true) const handleStructureOutputEnableChange = useCallback((enabled: boolean) => { @@ -327,6 +332,91 @@ const useConfig = (id: string, payload: LLMNodeType) => { setInputs(newInputs) }, [setInputs]) + const handleToolsChange = useCallback((tools: ToolValue[]) => { + const newInputs = 
produce(inputs, (draft) => { + draft.tools = tools + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + // Auto-manage ReAct prompt based on model support and tool selection + useEffect(() => { + if (!isChatModel) return + + // Add a small delay to ensure all state updates have settled + const timeoutId = setTimeout(() => { + const promptTemplate = inputs.prompt_template as PromptItem[] + const systemPromptIndex = promptTemplate.findIndex(item => item.role === 'system') + + const shouldHaveReactPrompt = inputs.tools && inputs.tools.length > 0 && !isModelSupportToolCall + + if (shouldHaveReactPrompt) { + // Should have ReAct prompt + let needsAdd = false + if (systemPromptIndex >= 0) { + const currentSystemPrompt = promptTemplate[systemPromptIndex].text + // Check if ReAct prompt is already present by looking for key phrases + needsAdd = !currentSystemPrompt.includes('{{tools}}') && !currentSystemPrompt.includes('{{tool_names}}') + } + else { + needsAdd = true + } + + if (needsAdd) { + const newInputs = produce(inputs, (draft) => { + const draftPromptTemplate = draft.prompt_template as PromptItem[] + const sysPromptIdx = draftPromptTemplate.findIndex(item => item.role === 'system') + + if (sysPromptIdx >= 0) { + // Append ReAct prompt to existing system prompt + draftPromptTemplate[sysPromptIdx].text + = `${draftPromptTemplate[sysPromptIdx].text}\n\n${REACT_PROMPT_TEMPLATE}` + } + else { + // Create new system prompt with ReAct template + draftPromptTemplate.unshift({ + role: PromptRole.system, + text: REACT_PROMPT_TEMPLATE, + }) + } + }) + setInputs(newInputs) + } + } + else { + // Should NOT have ReAct prompt - remove it if present + if (systemPromptIndex >= 0) { + const currentSystemPrompt = promptTemplate[systemPromptIndex].text + const hasReactPrompt = currentSystemPrompt.includes('{{tools}}') || currentSystemPrompt.includes('{{tool_names}}') + + if (hasReactPrompt) { + const newInputs = produce(inputs, (draft) => { + const draftPromptTemplate = draft.prompt_template as PromptItem[] + const sysPromptIdx = draftPromptTemplate.findIndex(item => item.role === 'system') + + if (sysPromptIdx >= 0) { + // Remove ReAct prompt from system prompt + let cleanedText = draftPromptTemplate[sysPromptIdx].text + // Remove the ReAct template + cleanedText = cleanedText.replace(`\n\n${REACT_PROMPT_TEMPLATE}`, '') + cleanedText = cleanedText.replace(REACT_PROMPT_TEMPLATE, '') + + // If system prompt is now empty, remove it entirely + if (cleanedText.trim() === '') + draftPromptTemplate.splice(sysPromptIdx, 1) + else + draftPromptTemplate[sysPromptIdx].text = cleanedText.trim() + } + }) + setInputs(newInputs) + } + } + } + }, 100) // Small delay to let other state updates settle + + return () => clearTimeout(timeoutId) + }, [inputs.tools?.length, isModelSupportToolCall, isChatModel, setInputs]) + const { availableVars, availableNodesWithParent, @@ -362,12 +452,14 @@ const useConfig = (id: string, payload: LLMNodeType) => { handleVisionResolutionEnabledChange, handleVisionResolutionChange, isModelSupportStructuredOutput, + isModelSupportToolCall, handleStructureOutputChange, structuredOutputCollapsed, setStructuredOutputCollapsed, handleStructureOutputEnableChange, filterJinja2InputVar, handleReasoningFormatChange, + handleToolsChange, } } diff --git a/web/app/components/workflow/run/agent-log/agent-log-trigger.tsx b/web/app/components/workflow/run/agent-log/agent-log-trigger.tsx index 85b37d72d6..8376073d9e 100644 --- a/web/app/components/workflow/run/agent-log/agent-log-trigger.tsx +++ 
b/web/app/components/workflow/run/agent-log/agent-log-trigger.tsx @@ -4,6 +4,7 @@ import type { AgentLogItemWithChildren, NodeTracing, } from '@/types/workflow' +import { BlockEnum } from '@/app/components/workflow/types' type AgentLogTriggerProps = { nodeInfo: NodeTracing @@ -14,9 +15,13 @@ const AgentLogTrigger = ({ onShowAgentOrToolLog, }: AgentLogTriggerProps) => { const { t } = useTranslation() - const { agentLog, execution_metadata } = nodeInfo + const { agentLog, execution_metadata, node_type } = nodeInfo const agentStrategy = execution_metadata?.tool_info?.agent_strategy + // For LLM node, show different label + const isLLMNode = node_type === BlockEnum.LLM + const label = isLLMNode ? t('workflow.nodes.llm.tools').toUpperCase() : t('workflow.nodes.agent.strategy.label') + return (
- {t('workflow.nodes.agent.strategy.label')} + {label}
{ - agentStrategy && ( + !isLLMNode && agentStrategy && (
{agentStrategy}
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index 33124907f3..485426b3ea 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -96,6 +96,7 @@ const NodePanel: FC = ({ const isRetryNode = hasRetryNode(nodeInfo.node_type) && !!nodeInfo.retryDetail?.length const isAgentNode = nodeInfo.node_type === BlockEnum.Agent && !!nodeInfo.agentLog?.length const isToolNode = nodeInfo.node_type === BlockEnum.Tool && !!nodeInfo.agentLog?.length + const isLLMNode = nodeInfo.node_type === BlockEnum.LLM && !!nodeInfo.agentLog?.length const inputsTitle = useMemo(() => { let text = t('workflow.common.input') @@ -188,7 +189,7 @@ const NodePanel: FC = ({ /> )} { - (isAgentNode || isToolNode) && onShowAgentOrToolLog && ( + (isAgentNode || isToolNode || isLLMNode) && onShowAgentOrToolLog && ( = ({ const isRetryNode = hasRetryNode(nodeInfo?.node_type) && !!nodeInfo?.retryDetail?.length const isAgentNode = nodeInfo?.node_type === BlockEnum.Agent && !!nodeInfo?.agentLog?.length const isToolNode = nodeInfo?.node_type === BlockEnum.Tool && !!nodeInfo?.agentLog?.length + const isLLMNode = nodeInfo?.node_type === BlockEnum.LLM && !!nodeInfo?.agentLog?.length return (
@@ -117,7 +118,7 @@ const ResultPanel: FC = ({ ) } { - (isAgentNode || isToolNode) && handleShowAgentOrToolLog && ( + (isAgentNode || isToolNode || isLLMNode) && handleShowAgentOrToolLog && ( { let { children } = node diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 636537c466..540bb4e43a 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -503,6 +503,9 @@ const translation = { contextTooltip: 'You can import Knowledge as context', notSetContextInPromptTip: 'To enable the context feature, please fill in the context variable in PROMPT.', prompt: 'prompt', + tools: 'Tools', + toolsCount: '{{count}} tools selected', + toolsNotSupportedWarning: 'This model does not support native tool calling. A ReAct prompt template has been automatically added to the system prompt to enable tool usage.', roleDescription: { system: 'Give high level instructions for the conversation', user: 'Provide instructions, queries, or any text-based input to the model', @@ -520,6 +523,7 @@ const translation = { output: 'Generate content', reasoning_content: 'Reasoning Content', usage: 'Model Usage Information', + generation: 'Generation details including reasoning, tool calls and their sequence', }, singleRun: { variable: 'Variable', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index e33941a6cd..3c1d1f1cb8 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -503,6 +503,9 @@ const translation = { contextTooltip: '您可以导入知识库作为上下文', notSetContextInPromptTip: '要启用上下文功能,请在提示中填写上下文变量。', prompt: '提示词', + tools: '工具', + toolsCount: '已选择 {{count}} 个工具', + toolsNotSupportedWarning: '该模型不支持原生工具调用功能。已自动在系统提示词中添加 ReAct 提示模板以启用工具使用。', addMessage: '添加消息', roleDescription: { system: '为对话提供高层指导', @@ -520,6 +523,7 @@ const translation = { output: '生成内容', reasoning_content: '推理内容', usage: '模型用量信息', + generation: '生成详情,包含推理内容、工具调用及其顺序', }, singleRun: { variable: '变量',