diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b798..66f4524156 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -202,6 +202,7 @@ message_detail_model = console_ns.model( "status": fields.String, "error": fields.String, "parent_message_id": fields.String, + "generation_detail": fields.Raw, }, ) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d342f4e661..bb908a8fb1 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -74,6 +74,7 @@ def build_message_model(api_or_ns: Namespace): "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "status": fields.String, "error": fields.String, + "generation_detail": fields.Raw, } return api_or_ns.model("Message", message_fields) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 5c7ea9e69a..51ce024a5b 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -85,6 +85,7 @@ class MessageListApi(WebApiResource): "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, + "generation_detail": fields.Raw, } message_infinite_scroll_pagination_fields = { diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py new file mode 100644 index 0000000000..2ee0a23aab --- /dev/null +++ b/api/core/agent/agent_app_runner.py @@ -0,0 +1,380 @@ +import logging +from collections.abc import Generator +from copy import deepcopy +from typing import Any + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.entities import AgentEntity, AgentLog, AgentResult +from core.agent.patterns.strategy_factory import StrategyFactory +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.file import file_manager +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMUsage, + PromptMessage, + PromptMessageContentType, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine +from models.model import Message + +logger = logging.getLogger(__name__) + + +class AgentAppRunner(BaseAgentRunner): + def _create_tool_invoke_hook(self, message: Message): + """ + Create a tool invoke hook that uses ToolEngine.agent_invoke. + This hook handles file creation and returns proper meta information. 
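+        The returned hook takes (tool, tool_args, tool_name) and returns (tool_invoke_response, message_file_ids, tool_invoke_meta).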
+ """ + # Get trace manager from app generate entity + trace_manager = self.application_generate_entity.trace_manager + + def tool_invoke_hook( + tool: Tool, tool_args: dict[str, Any], tool_name: str + ) -> tuple[str, list[str], ToolInvokeMeta]: + """Hook that uses agent_invoke for proper file and meta handling.""" + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters=tool_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback, + trace_manager=trace_manager, + app_id=self.application_generate_entity.app_config.app_id, + message_id=message.id, + conversation_id=self.conversation.id, + ) + + # Publish files and track IDs + for message_file_id in message_files: + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), + PublishFrom.APPLICATION_MANAGER, + ) + self._current_message_file_ids.append(message_file_id) + + return tool_invoke_response, message_files, tool_invoke_meta + + return tool_invoke_hook + + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: + """ + Run Agent application + """ + self.query = query + app_generate_entity = self.application_generate_entity + + app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" + + # convert tools into ModelRuntime Tool format + tool_instances, _ = self._init_prompt_tools() + + assert app_config.agent + + # Create tool invoke hook for agent_invoke + tool_invoke_hook = self._create_tool_invoke_hook(message) + + # Get instruction for ReAct strategy + instruction = self.app_config.prompt_template.simple_prompt_template or "" + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=self.model_features, + model_instance=self.model_instance, + tools=list(tool_instances.values()), + files=list(self.files), + max_iterations=app_config.agent.max_iteration, + context=self.build_execution_context(), + agent_strategy=self.config.strategy, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Initialize state variables + current_agent_thought_id = None + has_published_thought = False + current_tool_name: str | None = None + self._current_message_file_ids: list[str] = [] + + # organize prompt messages + prompt_messages = self._organize_prompt_messages() + + # Run strategy + generator = strategy.run( + prompt_messages=prompt_messages, + model_parameters=app_generate_entity.model_conf.parameters, + stop=app_generate_entity.model_conf.stop, + stream=True, + ) + + # Consume generator and collect result + result: AgentResult | None = None + try: + while True: + try: + output = next(generator) + except StopIteration as e: + # Generator finished, get the return value + result = e.value + break + + if isinstance(output, LLMResultChunk): + # Handle LLM chunk + if current_agent_thought_id and not has_published_thought: + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + has_published_thought = True + + yield output + + elif isinstance(output, AgentLog): + # Handle Agent Log using log_type for type-safe dispatch + if output.status == AgentLog.LogStatus.START: + if output.log_type == AgentLog.LogType.ROUND: + # Start of a new round + 
message_file_ids: list[str] = [] + current_agent_thought_id = self.create_agent_thought( + message_id=message.id, + message="", + tool_name="", + tool_input="", + messages_ids=message_file_ids, + ) + has_published_thought = False + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call start - extract data from structured fields + current_tool_name = output.data.get("tool_name", "") + tool_input = output.data.get("tool_args", {}) + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=current_tool_name, + tool_input=tool_input, + thought=None, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.status == AgentLog.LogStatus.SUCCESS: + if output.log_type == AgentLog.LogType.THOUGHT: + if current_agent_thought_id is None: + continue + + thought_text = output.data.get("thought") + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=thought_text, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call finished + tool_output = output.data.get("output") + # Get meta from strategy output (now properly populated) + tool_meta = output.data.get("meta") + + # Wrap tool_meta with tool_name as key (required by agent_service) + if tool_meta and current_tool_name: + tool_meta = {current_tool_name: tool_meta} + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=None, + observation=tool_output, + tool_invoke_meta=tool_meta, + answer=None, + messages_ids=self._current_message_file_ids, + ) + # Clear message file ids after saving + self._current_message_file_ids = [] + current_tool_name = None + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.ROUND: + if current_agent_thought_id is None: + continue + + # Round finished - save LLM usage and answer + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) + llm_result = output.data.get("llm_result") + final_answer = output.data.get("final_answer") + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=llm_result, + observation=None, + tool_invoke_meta=None, + answer=final_answer, + messages_ids=[], + llm_usage=llm_usage, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + except Exception: + # Re-raise any other exceptions + raise + + # Process final result + if isinstance(result, AgentResult): + final_answer = result.text + usage = result.usage or LLMUsage.empty_usage() + + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=self.model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, 
+ ) + + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Initialize system message + """ + if not prompt_template: + return prompt_messages or [] + + prompt_messages = prompt_messages or [] + + if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage(content=prompt_template) + return prompt_messages + + if not prompt_messages: + return [SystemPromptMessage(content=prompt_template)] + + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + return prompt_messages + + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + As for now, gpt supports both fc and vision at the first iteration. + We need to remove the image messages from the prompt messages at the first iteration. 
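+        Image and file contents in user messages are collapsed to "[image]" / "[file]" text placeholders after the first iteration.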
+ """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) + + return prompt_messages + + def _organize_prompt_messages(self): + # For ReAct strategy, use the agent prompt template + if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt: + prompt_template = self.config.prompt.first_prompt + else: + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" + + self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) + query_prompt_messages = self._organize_user_query(self.query or "", []) + + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=[*query_prompt_messages, *self._current_thoughts], + history_messages=self.history_prompt_messages, + memory=self.memory, + ).get_prompt() + + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] + if len(self._current_thoughts) != 0: + # clear messages after the first iteration + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + return prompt_messages diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c196dbbdf1..b59a9a3859 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -5,7 +5,7 @@ from typing import Union, cast from sqlalchemy import select -from core.agent.entities import AgentEntity, AgentToolEntity +from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -114,9 +114,20 @@ class BaseAgentRunner(AppRunner): features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] + self.model_features = features self.query: str | None = "" self._current_thoughts: list[PromptMessage] = [] + def build_execution_context(self) -> ExecutionContext: + """Build execution context.""" + return ExecutionContext( + user_id=self.user_id, + app_id=self.app_config.app_id, + conversation_id=self.conversation.id, + message_id=self.message.id, + tenant_id=self.tenant_id, + ) + def _repack_app_generate_entity( self, app_generate_entity: AgentChatAppGenerateEntity ) -> AgentChatAppGenerateEntity: diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py deleted file mode 100644 index b32e35d0ca..0000000000 --- a/api/core/agent/cot_agent_runner.py +++ /dev/null @@ -1,431 +0,0 @@ -import json -import logging -from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping, Sequence -from typing import Any - -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.entities import AgentScratchpadUnit -from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from 
core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from core.ops.ops_trace_manager import TraceQueueManager -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from models.model import Message - -logger = logging.getLogger(__name__) - - -class CotAgentRunner(BaseAgentRunner, ABC): - _is_first_iteration = True - _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] - _agent_scratchpad: list[AgentScratchpadUnit] - _instruction: str - _query: str - _prompt_messages_tools: Sequence[PromptMessageTool] - - def run( - self, - message: Message, - query: str, - inputs: Mapping[str, str], - ) -> Generator: - """ - Run Cot agent application - """ - - app_generate_entity = self.application_generate_entity - self._repack_app_generate_entity(app_generate_entity) - self._init_react_state(query) - - trace_manager = app_generate_entity.trace_manager - - # check model mode - if "Observation" not in app_generate_entity.model_conf.stop: - if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append("Observation") - - app_config = self.app_config - assert app_config.agent - - # init instruction - inputs = inputs or {} - instruction = app_config.prompt_template.simple_prompt_template or "" - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - self._prompt_messages_tools = prompt_messages_tools - - function_call_state = True - llm_usage: dict[str, LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - agent_thought_id = "" # Initialize agent_thought_id - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - # continue to run until there is not any tool call - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - self._prompt_messages_tools = [] - - message_file_ids: list[str] = [] - - agent_thought_id = self.create_agent_thought( - message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids - ) - - if iteration_step > 1: - self.queue_manager.publish( - 
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=[], - stop=app_generate_entity.model_conf.stop, - stream=True, - user=self.user_id, - callbacks=[], - ) - - usage_dict: dict[str, LLMUsage | None] = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) - scratchpad = AgentScratchpadUnit( - agent_response="", - thought="", - action_str="", - observation="", - action=None, - ) - - # publish agent thought if it's first iteration - if iteration_step == 1: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - for chunk in react_chunks: - if isinstance(chunk, AgentScratchpadUnit.Action): - action = chunk - # detect action - assert scratchpad.agent_response is not None - scratchpad.agent_response += json.dumps(chunk.model_dump()) - scratchpad.action_str = json.dumps(chunk.model_dump()) - scratchpad.action = action - else: - assert scratchpad.agent_response is not None - scratchpad.agent_response += chunk - assert scratchpad.thought is not None - scratchpad.thought += chunk - yield LLMResultChunk( - model=self.model_config.model, - prompt_messages=prompt_messages, - system_fingerprint="", - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), - ) - - assert scratchpad.thought is not None - scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) - - # get llm usage - if "usage" in usage_dict: - if usage_dict["usage"] is not None: - increase_usage(llm_usage, usage_dict["usage"]) - else: - usage_dict["usage"] = LLMUsage.empty_usage() - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), - tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, - tool_invoke_meta={}, - thought=scratchpad.thought or "", - observation="", - answer=scratchpad.agent_response or "", - messages_ids=[], - llm_usage=usage_dict["usage"], - ) - - if not scratchpad.is_final(): - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - if not scratchpad.action: - # failed to extract action, return final answer directly - final_answer = "" - else: - if scratchpad.action.action_name.lower() == "final answer": - # action is final answer, return final answer directly - try: - if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False) - elif isinstance(scratchpad.action.action_input, str): - final_answer = scratchpad.action.action_input - else: - final_answer = f"{scratchpad.action.action_input}" - except TypeError: - final_answer = f"{scratchpad.action.action_input}" - else: - function_call_state = True - # action is tool call, invoke tool - tool_invoke_response, tool_invoke_meta = self._handle_invoke_action( - action=scratchpad.action, - tool_instances=tool_instances, - message_file_ids=message_file_ids, - trace_manager=trace_manager, - ) - 
scratchpad.observation = tool_invoke_response - scratchpad.agent_response = tool_invoke_response - - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=scratchpad.action.action_name, - tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought or "", - observation={scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, - answer=scratchpad.agent_response, - messages_ids=message_file_ids, - llm_usage=usage_dict["usage"], - ) - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool message - for prompt_tool in self._prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] - ), - system_fingerprint="", - ) - - # save agent thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name="", - tool_input={}, - tool_invoke_meta={}, - thought=final_answer, - observation={}, - answer=final_answer, - messages_ids=[], - ) - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - def _handle_invoke_action( - self, - action: AgentScratchpadUnit.Action, - tool_instances: Mapping[str, Tool], - message_file_ids: list[str], - trace_manager: TraceQueueManager | None = None, - ) -> tuple[str, ToolInvokeMeta]: - """ - handle invoke action - :param action: action - :param tool_instances: tool instances - :param message_file_ids: message file ids - :param trace_manager: trace manager - :return: observation, meta - """ - # action is tool call, invoke tool - tool_call_name = action.action_name - tool_call_args = action.action_input - tool_instance = tool_instances.get(tool_call_name) - - if not tool_instance: - answer = f"there is not a tool named {tool_call_name}" - return answer, ToolInvokeMeta.error_instance(answer) - - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass - - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - ) - - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - return tool_invoke_response, tool_invoke_meta - - def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: - """ - convert dict to action - """ - return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - - def 
_fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: - """ - fill in inputs from external data tools - """ - for key, value in inputs.items(): - try: - instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) - except Exception: - continue - - return instruction - - def _init_react_state(self, query): - """ - init agent scratchpad - """ - self._query = query - self._agent_scratchpad = [] - self._historic_prompt_messages = self._organize_historic_prompt_messages() - - @abstractmethod - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - organize prompt messages - """ - - def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: - """ - format assistant message - """ - message = "" - for scratchpad in agent_scratchpad: - if scratchpad.is_final(): - message += f"Final Answer: {scratchpad.agent_response}" - else: - message += f"Thought: {scratchpad.thought}\n\n" - if scratchpad.action_str: - message += f"Action: {scratchpad.action_str}\n\n" - if scratchpad.observation: - message += f"Observation: {scratchpad.observation}\n\n" - - return message - - def _organize_historic_prompt_messages( - self, current_session_messages: list[PromptMessage] | None = None - ) -> list[PromptMessage]: - """ - organize historic prompt messages - """ - result: list[PromptMessage] = [] - scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit | None = None - - for message in self.history_prompt_messages: - if isinstance(message, AssistantPromptMessage): - if not current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad = AgentScratchpadUnit( - agent_response=message.content, - thought=message.content or "I am thinking about how to help you", - action_str="", - action=None, - observation=None, - ) - scratchpads.append(current_scratchpad) - if message.tool_calls: - try: - current_scratchpad.action = AgentScratchpadUnit.Action( - action_name=message.tool_calls[0].function.name, - action_input=json.loads(message.tool_calls[0].function.arguments), - ) - current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) - except Exception: - logger.exception("Failed to parse tool call from assistant message") - elif isinstance(message, ToolPromptMessage): - if current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad.observation = message.content - else: - raise NotImplementedError("expected str type") - elif isinstance(message, UserPromptMessage): - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - scratchpads = [] - current_scratchpad = None - - result.append(message) - - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - - historic_prompts = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=current_session_messages or [], - history_messages=result, - memory=self.memory, - ).get_prompt() - return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py deleted file mode 100644 index 4d1d94eadc..0000000000 --- a/api/core/agent/cot_chat_agent_runner.py +++ /dev/null @@ -1,118 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - PromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - 
UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotChatAgentRunner(CotAgentRunner): - def _organize_system_prompt(self) -> SystemPromptMessage: - """ - Organize system prompt - """ - assert self.app_config.agent - assert self.app_config.agent.prompt - - prompt_entity = self.app_config.agent.prompt - if not prompt_entity: - raise ValueError("Agent prompt configuration is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return SystemPromptMessage(content=system_prompt) - - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize - """ - # organize system prompt - system_message = self._organize_system_prompt() - - # organize current assistant messages - agent_scratchpad = self._agent_scratchpad - if not agent_scratchpad: - assistant_messages = [] - else: - assistant_message = AssistantPromptMessage(content="") - assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str - for unit in agent_scratchpad: - if unit.is_final(): - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Final Answer: {unit.agent_response}" - else: - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_message.content += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_message.content += f"Observation: {unit.observation}\n\n" - - assistant_messages = [assistant_message] - - # query messages - query_messages = self._organize_user_query(self._query, []) - - if assistant_messages: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages( - [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] - ) - messages = [ - system_message, - *historic_messages, - *query_messages, - *assistant_messages, - UserPromptMessage(content="continue"), - ] - else: - # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages]) - messages = [system_message, *historic_messages, 
*query_messages] - - # join all messages - return messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py deleted file mode 100644 index da9a001d84..0000000000 --- a/api/core/agent/cot_completion_agent_runner.py +++ /dev/null @@ -1,87 +0,0 @@ -import json - -from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.utils.encoders import jsonable_encoder - - -class CotCompletionAgentRunner(CotAgentRunner): - def _organize_instruction_prompt(self) -> str: - """ - Organize instruction prompt - """ - if self.app_config.agent is None: - raise ValueError("Agent configuration is not set") - prompt_entity = self.app_config.agent.prompt - if prompt_entity is None: - raise ValueError("prompt entity is not set") - first_prompt = prompt_entity.first_prompt - - system_prompt = ( - first_prompt.replace("{{instruction}}", self._instruction) - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) - .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) - ) - - return system_prompt - - def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str: - """ - Organize historic prompt - """ - historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages) - historic_prompt = "" - - for message in historic_prompt_messages: - if isinstance(message, UserPromptMessage): - historic_prompt += f"Question: {message.content}\n\n" - elif isinstance(message, AssistantPromptMessage): - if isinstance(message.content, str): - historic_prompt += message.content + "\n\n" - elif isinstance(message.content, list): - for content in message.content: - if not isinstance(content, TextPromptMessageContent): - continue - historic_prompt += content.data - - return historic_prompt - - def _organize_prompt_messages(self) -> list[PromptMessage]: - """ - Organize prompt messages - """ - # organize system prompt - system_prompt = self._organize_instruction_prompt() - - # organize historic prompt messages - historic_prompt = self._organize_historic_prompt() - - # organize current assistant messages - agent_scratchpad = self._agent_scratchpad - assistant_prompt = "" - for unit in agent_scratchpad or []: - if unit.is_final(): - assistant_prompt += f"Final Answer: {unit.agent_response}" - else: - assistant_prompt += f"Thought: {unit.thought}\n\n" - if unit.action_str: - assistant_prompt += f"Action: {unit.action_str}\n\n" - if unit.observation: - assistant_prompt += f"Observation: {unit.observation}\n\n" - - # query messages - query_prompt = f"Question: {self._query}" - - # join all messages - prompt = ( - system_prompt.replace("{{historic_messages}}", historic_prompt) - .replace("{{agent_scratchpad}}", assistant_prompt) - .replace("{{query}}", query_prompt) - ) - - return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 220feced1d..56319a14a3 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,3 +1,5 @@ +import uuid +from collections.abc import Mapping from enum import StrEnum from typing import Any, Union @@ -92,3 +94,94 @@ class AgentInvokeMessage(ToolInvokeMessage): """ pass + + +class ExecutionContext(BaseModel): + """Execution context containing trace and audit information. 
+ + This context carries all the IDs and metadata that are not part of + the core business logic but needed for tracing, auditing, and + correlation purposes. + """ + + user_id: str | None = None + app_id: str | None = None + conversation_id: str | None = None + message_id: str | None = None + tenant_id: str | None = None + + @classmethod + def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext": + """Create a minimal context with only essential fields.""" + return cls(user_id=user_id) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for passing to legacy code.""" + return { + "user_id": self.user_id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "message_id": self.message_id, + "tenant_id": self.tenant_id, + } + + def with_updates(self, **kwargs) -> "ExecutionContext": + """Create a new context with updated fields.""" + data = self.to_dict() + data.update(kwargs) + + return ExecutionContext( + user_id=data.get("user_id"), + app_id=data.get("app_id"), + conversation_id=data.get("conversation_id"), + message_id=data.get("message_id"), + tenant_id=data.get("tenant_id"), + ) + + +class AgentLog(BaseModel): + """ + Agent Log. + """ + + class LogType(StrEnum): + """Type of agent log entry.""" + + ROUND = "round" # A complete iteration round + THOUGHT = "thought" # LLM thinking/reasoning + TOOL_CALL = "tool_call" # Tool invocation + + class LogMetadata(StrEnum): + STARTED_AT = "started_at" + FINISHED_AT = "finished_at" + ELAPSED_TIME = "elapsed_time" + TOTAL_PRICE = "total_price" + TOTAL_TOKENS = "total_tokens" + PROVIDER = "provider" + CURRENCY = "currency" + LLM_USAGE = "llm_usage" + + class LogStatus(StrEnum): + START = "start" + ERROR = "error" + SUCCESS = "success" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log") + label: str = Field(..., description="The label of the log") + log_type: LogType = Field(..., description="The type of the log") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") + status: LogStatus = Field(..., description="The status of the log") + data: Mapping[str, Any] = Field(..., description="Detailed log data") + metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log") + + +class AgentResult(BaseModel): + """ + Agent execution result. 
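+    Returned as the return value of the AgentPattern.run generator.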
+ """ + + text: str = Field(default="", description="The generated text") + files: list[Any] = Field(default_factory=list, description="Files produced during execution") + usage: Any | None = Field(default=None, description="LLM usage statistics") + finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py deleted file mode 100644 index dcc1326b33..0000000000 --- a/api/core/agent/fc_agent_runner.py +++ /dev/null @@ -1,465 +0,0 @@ -import json -import logging -from collections.abc import Generator -from copy import deepcopy -from typing import Any, Union - -from core.agent.base_agent_runner import BaseAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.file import file_manager -from core.model_runtime.entities import ( - AssistantPromptMessage, - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMUsage, - PromptMessage, - PromptMessageContentType, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from models.model import Message - -logger = logging.getLogger(__name__) - - -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: - """ - Run FunctionCall agent application - """ - self.query = query - app_generate_entity = self.application_generate_entity - - app_config = self.app_config - assert app_config is not None, "app_config is required" - assert app_config.agent is not None, "app_config.agent is required" - - # convert tools into ModelRuntime Tool format - tool_instances, prompt_messages_tools = self._init_prompt_tools() - - assert app_config.agent - - iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 - - # continue to run until there is not any tool call - function_call_state = True - llm_usage: dict[str, LLMUsage | None] = {"usage": None} - final_answer = "" - prompt_messages: list = [] # Initialize prompt_messages - - # get tracing instance - trace_manager = app_generate_entity.trace_manager - - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage): - if not final_llm_usage_dict["usage"]: - final_llm_usage_dict["usage"] = usage - else: - llm_usage = final_llm_usage_dict["usage"] - llm_usage.prompt_tokens += usage.prompt_tokens - llm_usage.completion_tokens += usage.completion_tokens - llm_usage.total_tokens += usage.total_tokens - llm_usage.prompt_price += usage.prompt_price - llm_usage.completion_price += usage.completion_price - llm_usage.total_price += usage.total_price - - model_instance = self.model_instance - - while function_call_state and iteration_step <= max_iteration_steps: - function_call_state = False - - if iteration_step == max_iteration_steps: - # the last iteration, remove all tools - prompt_messages_tools = [] - - message_file_ids: list[str] = [] - agent_thought_id = self.create_agent_thought( - message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids - ) - - # 
recalc llm max tokens - prompt_messages = self._organize_prompt_messages() - self.recalc_llm_max_tokens(self.model_config, prompt_messages) - # invoke model - chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=app_generate_entity.model_conf.parameters, - tools=prompt_messages_tools, - stop=app_generate_entity.model_conf.stop, - stream=self.stream_tool_call, - user=self.user_id, - callbacks=[], - ) - - tool_calls: list[tuple[str, str, dict[str, Any]]] = [] - - # save full response - response = "" - - # save tool call names and inputs - tool_call_names = "" - tool_call_inputs = "" - - current_llm_usage = None - - if isinstance(chunks, Generator): - is_first_chunk = True - for chunk in chunks: - if is_first_chunk: - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - is_first_chunk = False - # check if there is any tool call - if self.check_tool_calls(chunk): - function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if chunk.delta.message and chunk.delta.message.content: - if isinstance(chunk.delta.message.content, list): - for content in chunk.delta.message.content: - response += content.data - else: - response += str(chunk.delta.message.content) - - if chunk.delta.usage: - increase_usage(llm_usage, chunk.delta.usage) - current_llm_usage = chunk.delta.usage - - yield chunk - else: - result = chunks - # check if there is any tool call - if self.check_blocking_tool_calls(result): - function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result) or []) - tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) - try: - tool_call_inputs = json.dumps( - {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False - ) - except TypeError: - # fallback: force ASCII to handle non-serializable objects - tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) - - if result.usage: - increase_usage(llm_usage, result.usage) - current_llm_usage = result.usage - - if result.message and result.message.content: - if isinstance(result.message.content, list): - for content in result.message.content: - response += content.data - else: - response += str(result.message.content) - - if not result.message.content: - result.message.content = "" - - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - yield LLMResultChunk( - model=model_instance.model, - prompt_messages=result.prompt_messages, - system_fingerprint=result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=result.message, - usage=result.usage, - ), - ) - - assistant_message = AssistantPromptMessage(content="", tool_calls=[]) - if tool_calls: - assistant_message.tool_calls = [ - AssistantPromptMessage.ToolCall( - id=tool_call[0], - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) - ), - ) - for tool_call in 
tool_calls - ] - else: - assistant_message.content = response - - self._current_thoughts.append(assistant_message) - - # save thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name=tool_call_names, - tool_input=tool_call_inputs, - thought=response, - tool_invoke_meta=None, - observation=None, - answer=response, - messages_ids=[], - llm_usage=current_llm_usage, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - final_answer += response + "\n" - - # call tools - tool_responses = [] - for tool_call_id, tool_call_name, tool_call_args in tool_calls: - tool_instance = tool_instances.get(tool_call_name) - if not tool_instance: - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), - } - else: - # invoke tool - tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( - tool=tool_instance, - tool_parameters=tool_call_args, - user_id=self.user_id, - tenant_id=self.tenant_id, - message=self.message, - invoke_from=self.application_generate_entity.invoke_from, - agent_tool_callback=self.agent_callback, - trace_manager=trace_manager, - app_id=self.application_generate_entity.app_config.app_id, - message_id=self.message.id, - conversation_id=self.conversation.id, - ) - # publish files - for message_file_id in message_files: - # publish message file - self.queue_manager.publish( - QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER - ) - # add message file ids - message_file_ids.append(message_file_id) - - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict(), - } - - tool_responses.append(tool_response) - if tool_response["tool_response"] is not None: - self._current_thoughts.append( - ToolPromptMessage( - content=str(tool_response["tool_response"]), - tool_call_id=tool_call_id, - name=tool_call_name, - ) - ) - - if len(tool_responses) > 0: - # save agent thought - self.save_agent_thought( - agent_thought_id=agent_thought_id, - tool_name="", - tool_input="", - thought="", - tool_invoke_meta={ - tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses - }, - observation={ - tool_response["tool_call_name"]: tool_response["tool_response"] - for tool_response in tool_responses - }, - answer="", - messages_ids=message_file_ids, - ) - self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER - ) - - # update prompt tool - for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) - - iteration_step += 1 - - # publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=llm_usage["usage"] or LLMUsage.empty_usage(), - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) - - def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: - """ - Check if there is any tool call in llm result chunk - """ - if llm_result_chunk.delta.message.tool_calls: - return True - return False - - def 
check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: - """ - Check if there is any blocking tool call in llm result - """ - if llm_result.message.tool_calls: - return True - return False - - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract tool calls from llm result chunk - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result_chunk.delta.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: - """ - Extract blocking tool calls from llm result - - Returns: - List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] - """ - tool_calls = [] - for prompt_message in llm_result.message.tool_calls: - args = {} - if prompt_message.function.arguments != "": - args = json.loads(prompt_message.function.arguments) - - tool_calls.append( - ( - prompt_message.id, - prompt_message.function.name, - args, - ) - ) - - return tool_calls - - def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Initialize system message - """ - if not prompt_messages and prompt_template: - return [ - SystemPromptMessage(content=prompt_template), - ] - - if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: - prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - - return prompt_messages or [] - - def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Organize user query - """ - if self.files: - # get image detail config - image_detail_config = ( - self.application_generate_entity.file_upload_config.image_config.detail - if ( - self.application_generate_entity.file_upload_config - and self.application_generate_entity.file_upload_config.image_config - ) - else None - ) - image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - for file in self.files: - prompt_message_contents.append( - file_manager.to_prompt_message_content( - file, - image_detail_config=image_detail_config, - ) - ) - prompt_message_contents.append(TextPromptMessageContent(data=query)) - - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) - else: - prompt_messages.append(UserPromptMessage(content=query)) - - return prompt_messages - - def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - As for now, gpt supports both fc and vision at the first iteration. - We need to remove the image messages from the prompt messages at the first iteration. 
- """ - prompt_messages = deepcopy(prompt_messages) - - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, list): - prompt_message.content = "\n".join( - [ - content.data - if content.type == PromptMessageContentType.TEXT - else "[image]" - if content.type == PromptMessageContentType.IMAGE - else "[file]" - for content in prompt_message.content - ] - ) - - return prompt_messages - - def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or "" - self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query or "", []) - - self.history_prompt_messages = AgentHistoryPromptTransform( - model_config=self.model_config, - prompt_messages=[*query_prompt_messages, *self._current_thoughts], - history_messages=self.history_prompt_messages, - memory=self.memory, - ).get_prompt() - - prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] - if len(self._current_thoughts) != 0: - # clear messages after the first iteration - prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) - return prompt_messages diff --git a/api/core/agent/patterns/README.md b/api/core/agent/patterns/README.md new file mode 100644 index 0000000000..95b1bf87fa --- /dev/null +++ b/api/core/agent/patterns/README.md @@ -0,0 +1,55 @@ +# Agent Patterns + +A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability. + +## Overview + +The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently. + +## Key Features + +- **Dual strategies** + - `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`. + - `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested. +- **Explicit or auto selection** + - `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT). + - Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not. +- **Unified execution contract** + - `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`. + - Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools. +- **Tool handling and hooks** + - Tools convert to `PromptMessageTool` objects before invocation. + - Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`. + - Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs. +- **File-aware arguments** + - Tool args accept `[File: ]` or `[Files: ]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely. 
+- **ReAct prompt shaping** + - System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders. + - Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history. +- **Observability and accounting** + - Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths. + +## Architecture + +``` +agent/patterns/ +├── base.py # Shared utilities: logging, usage, tool invocation, file handling +├── function_call.py # Native function-calling loop with tool execution +├── react.py # ReAct loop with CoT parsing and scratchpad wiring +└── strategy_factory.py # Strategy selection by model features or explicit override +``` + +## Usage + +- For auto-selection: + - Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params. +- For explicit behavior: + - Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct. +- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`. + +## Integration Points + +- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls. +- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers. +- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments. +- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging. diff --git a/api/core/agent/patterns/__init__.py b/api/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..8a3b125533 --- /dev/null +++ b/api/core/agent/patterns/__init__.py @@ -0,0 +1,19 @@ +"""Agent patterns module. 
+ +This module provides different strategies for agent execution: +- FunctionCallStrategy: Uses native function/tool calling +- ReActStrategy: Uses ReAct (Reasoning + Acting) approach +- StrategyFactory: Factory for creating strategies based on model features +""" + +from .base import AgentPattern +from .function_call import FunctionCallStrategy +from .react import ReActStrategy +from .strategy_factory import StrategyFactory + +__all__ = [ + "AgentPattern", + "FunctionCallStrategy", + "ReActStrategy", + "StrategyFactory", +] diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py new file mode 100644 index 0000000000..9f010bed6a --- /dev/null +++ b/api/core/agent/patterns/base.py @@ -0,0 +1,444 @@ +"""Base class for agent strategies.""" + +from __future__ import annotations + +import json +import re +import time +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING, Any + +from core.agent.entities import AgentLog, AgentResult, ExecutionContext +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import TextPromptMessageContent +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + +# Type alias for tool invoke hook +# Returns: (response_content, message_file_ids, tool_invoke_meta) +ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]] + + +class AgentPattern(ABC): + """Base class for agent execution strategies.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + ): + """Initialize the agent strategy.""" + self.model_instance = model_instance + self.tools = tools + self.context = context + self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations + self.workflow_call_depth = workflow_call_depth + self.files: list[File] = files + self.tool_invoke_hook = tool_invoke_hook + + @abstractmethod + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the agent strategy.""" + pass + + def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None: + """Accumulate LLM usage statistics.""" + if not total_usage.get("usage"): + # Create a copy to avoid modifying the original + total_usage["usage"] = LLMUsage( + prompt_tokens=delta_usage.prompt_tokens, + prompt_unit_price=delta_usage.prompt_unit_price, + prompt_price_unit=delta_usage.prompt_price_unit, + prompt_price=delta_usage.prompt_price, + completion_tokens=delta_usage.completion_tokens, + completion_unit_price=delta_usage.completion_unit_price, + completion_price_unit=delta_usage.completion_price_unit, + completion_price=delta_usage.completion_price, + total_tokens=delta_usage.total_tokens, + total_price=delta_usage.total_price, + currency=delta_usage.currency, + latency=delta_usage.latency, + ) + else: + current: LLMUsage = total_usage["usage"] + 
current.prompt_tokens += delta_usage.prompt_tokens + current.completion_tokens += delta_usage.completion_tokens + current.total_tokens += delta_usage.total_tokens + current.prompt_price += delta_usage.prompt_price + current.completion_price += delta_usage.completion_price + current.total_price += delta_usage.total_price + + def _extract_content(self, content: Any) -> str: + """Extract text content from message content.""" + if isinstance(content, list): + # Content items are PromptMessageContentUnionTypes + text_parts = [] + for c in content: + # Check if it's a TextPromptMessageContent (which has data attribute) + if isinstance(c, TextPromptMessageContent): + text_parts.append(c.data) + return "".join(text_parts) + return str(content) + + def _has_tool_calls(self, chunk: LLMResultChunk) -> bool: + """Check if chunk contains tool calls.""" + # LLMResultChunk always has delta attribute + return bool(chunk.delta.message and chunk.delta.message.tool_calls) + + def _has_tool_calls_result(self, result: LLMResult) -> bool: + """Check if result contains tool calls (non-streaming).""" + # LLMResult always has message attribute + return bool(result.message and result.message.tool_calls) + + def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from streaming chunk.""" + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + if chunk.delta.message and chunk.delta.message.tool_calls: + for tool_call in chunk.delta.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from non-streaming result.""" + tool_calls = [] + if result.message and result.message.tool_calls: + for tool_call in result.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_text_from_message(self, message: PromptMessage) -> str: + """Extract text content from a prompt message.""" + # PromptMessage always has content attribute + content = message.content + if isinstance(content, str): + return content + elif isinstance(content, list): + # Extract text from content list + text_parts = [] + for item in content: + if isinstance(item, TextPromptMessageContent): + text_parts.append(item.data) + return " ".join(text_parts) + return "" + + def _create_log( + self, + label: str, + log_type: AgentLog.LogType, + status: AgentLog.LogStatus, + data: dict[str, Any] | None = None, + parent_id: str | None = None, + extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None, + ) -> AgentLog: + """Create a new AgentLog with standard metadata.""" + metadata = { + AgentLog.LogMetadata.STARTED_AT: time.perf_counter(), + } + if extra_metadata: + metadata.update(extra_metadata) + + return AgentLog( + label=label, + log_type=log_type, + status=status, + data=data or {}, + parent_id=parent_id, + metadata=metadata, + ) + + def _finish_log( + self, + log: AgentLog, + data: dict[str, Any] | None = None, + usage: LLMUsage | None = None, + ) -> AgentLog: + """Finish an AgentLog by updating its status and 
metadata.""" + log.status = AgentLog.LogStatus.SUCCESS + + if data is not None: + log.data = data + + # Calculate elapsed time + started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter()) + finished_at = time.perf_counter() + + # Update metadata + log.metadata = { + **log.metadata, + AgentLog.LogMetadata.FINISHED_AT: finished_at, + AgentLog.LogMetadata.ELAPSED_TIME: finished_at - started_at, + } + + # Add usage information if provided + if usage: + log.metadata.update( + { + AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price, + AgentLog.LogMetadata.CURRENCY: usage.currency, + AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens, + AgentLog.LogMetadata.LLM_USAGE: usage, + } + ) + + return log + + def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]: + """ + Replace file references in tool arguments with actual File objects. + + Args: + tool_args: Dictionary of tool arguments + + Returns: + Updated tool arguments with file references replaced + """ + # Process each argument in the dictionary + processed_args: dict[str, Any] = {} + for key, value in tool_args.items(): + processed_args[key] = self._process_file_reference(value) + return processed_args + + def _process_file_reference(self, data: Any) -> Any: + """ + Recursively process data to replace file references. + Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...]. + + Args: + data: The data to process (can be dict, list, str, or other types) + + Returns: + Processed data with file references replaced + """ + single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$") + multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$") + + if isinstance(data, dict): + # Process dictionary recursively + return {key: self._process_file_reference(value) for key, value in data.items()} + elif isinstance(data, list): + # Process list recursively + return [self._process_file_reference(item) for item in data] + elif isinstance(data, str): + # Check for single file pattern [File: file_id] + single_match = single_file_pattern.match(data.strip()) + if single_match: + file_id = single_match.group(1).strip() + # Find the file in self.files + for file in self.files: + if file.id and str(file.id) == file_id: + return file + # If file not found, return original value + return data + + # Check for multiple files pattern [Files: file_id1, file_id2, ...] 
+ multiple_match = multiple_files_pattern.match(data.strip()) + if multiple_match: + file_ids_str = multiple_match.group(1).strip() + # Split by comma and strip whitespace + file_ids = [fid.strip() for fid in file_ids_str.split(",")] + + # Find all matching files + matched_files: list[File] = [] + for file_id in file_ids: + for file in self.files: + if file.id and str(file.id) == file_id: + matched_files.append(file) + break + + # Return list of files if any were found, otherwise return original + return matched_files or data + + return data + else: + # Return other types as-is + return data + + def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk: + """Create a text chunk for streaming.""" + return LLMResultChunk( + model=self.model_instance.model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + usage=None, + ), + system_fingerprint="", + ) + + def _invoke_tool( + self, + tool_instance: Tool, + tool_args: dict[str, Any], + tool_name: str, + ) -> tuple[str, list[File], ToolInvokeMeta | None]: + """ + Invoke a tool and collect its response. + + Args: + tool_instance: The tool instance to invoke + tool_args: Tool arguments + tool_name: Name of the tool + + Returns: + Tuple of (response_content, tool_files, tool_invoke_meta) + """ + # Process tool_args to replace file references with actual File objects + tool_args = self._replace_file_references(tool_args) + + # If a tool invoke hook is set, use it instead of generic_invoke + if self.tool_invoke_hook: + response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name) + # Note: message_file_ids are stored in DB, we don't convert them to File objects here + # The caller (AgentAppRunner) handles file publishing + return response_content, [], tool_invoke_meta + + # Default: use generic_invoke for workflow scenarios + # Import here to avoid circular import + from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine + + tool_response = ToolEngine().generic_invoke( + tool=tool_instance, + tool_parameters=tool_args, + user_id=self.context.user_id or "", + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + app_id=self.context.app_id, + conversation_id=self.context.conversation_id, + message_id=self.context.message_id, + ) + + # Collect response and files + response_content = "" + tool_files: list[File] = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + response_content += response.message.text + + elif response.type == ToolInvokeMessage.MessageType.LINK: + # Handle link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Link: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE: + # Handle image URL messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK: + # Handle image link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK: + # Handle binary file link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + filename = 
response.meta.get("filename", "file") if response.meta else "file" + response_content += f"[File: {filename} - {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.JSON: + # Handle JSON messages + if isinstance(response.message, ToolInvokeMessage.JsonMessage): + response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2) + + elif response.type == ToolInvokeMessage.MessageType.BLOB: + # Handle blob messages - convert to text representation + if isinstance(response.message, ToolInvokeMessage.BlobMessage): + mime_type = ( + response.meta.get("mime_type", "application/octet-stream") + if response.meta + else "application/octet-stream" + ) + size = len(response.message.blob) + response_content += f"[Binary data: {mime_type}, size: {size} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + # Handle variable messages + if isinstance(response.message, ToolInvokeMessage.VariableMessage): + var_name = response.message.variable_name + var_value = response.message.variable_value + if isinstance(var_value, str): + response_content += var_value + else: + response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]" + + elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: + # Handle blob chunk messages - these are parts of a larger blob + if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage): + response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + # Handle retriever resources messages + if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage): + response_content += response.message.context + + elif response.type == ToolInvokeMessage.MessageType.FILE: + # Extract file from meta + if response.meta and "file" in response.meta: + file = response.meta["file"] + if isinstance(file, File): + # Check if file is for model or tool output + if response.meta.get("target") == "self": + # File is for model - add to files for next prompt + self.files.append(file) + response_content += f"File '{file.filename}' has been loaded into your context." 
+ else: + # File is tool output + tool_files.append(file) + + return response_content, tool_files, None + + def _find_tool_by_name(self, tool_name: str) -> Tool | None: + """Find a tool instance by its name.""" + for tool in self.tools: + if tool.entity.identity.name == tool_name: + return tool + return None + + def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]: + """Convert tools to prompt message format.""" + prompt_tools: list[PromptMessageTool] = [] + for tool in self.tools: + prompt_tools.append(tool.to_prompt_message_tool()) + return prompt_tools + + def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None: + """Initialize usage tracking with empty usage if not set.""" + if "usage" not in llm_usage or llm_usage["usage"] is None: + llm_usage["usage"] = LLMUsage.empty_usage() diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py new file mode 100644 index 0000000000..a46c5d77f9 --- /dev/null +++ b/api/core/agent/patterns/function_call.py @@ -0,0 +1,295 @@ +"""Function Call strategy implementation.""" + +import json +from collections.abc import Generator +from typing import Any, Union + +from core.agent.entities import AgentLog, AgentResult +from core.file import File +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, +) +from core.tools.entities.tool_entities import ToolInvokeMeta + +from .base import AgentPattern + + +class FunctionCallStrategy(AgentPattern): + """Function Call strategy using model's native tool calling capability.""" + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the function call agent strategy.""" + # Convert tools to prompt format + prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + + # Initialize tracking + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + function_call_state: bool = True + total_usage: dict[str, LLMUsage | None] = {"usage": None} + messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy + final_text: str = "" + finish_reason: str | None = None + output_files: list[File] = [] # Track files produced by tools + + while function_call_state and iteration_step <= max_iterations: + function_call_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + # On last iteration, remove tools to force final answer + current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, LLMUsage | None] = {"usage": None} + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages, + model_parameters=model_parameters, + tools=current_tools, + stop=stop, + stream=stream, + user=self.context.user_id, + callbacks=[], 
+ ) + + # Process response + tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log + ) + messages.append(self._create_assistant_message(response_content, tool_calls)) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update final text if no tool calls (this is likely the final answer) + if not tool_calls: + final_text = response_content + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Process tool calls + tool_outputs: dict[str, str] = {} + if tool_calls: + function_call_state = True + # Execute tools + for tool_call_id, tool_name, tool_args in tool_calls: + tool_response, tool_files, _ = yield from self._handle_tool_call( + tool_name, tool_args, tool_call_id, messages, round_log + ) + tool_outputs[tool_name] = tool_response + # Track files produced by tools + output_files.extend(tool_files) + yield self._finish_log( + round_log, + data={ + "llm_result": response_content, + "tool_calls": [ + {"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls + ] + if tool_calls + else [], + "final_answer": final_text if not function_call_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, + files=output_files, + usage=total_usage.get("usage") or LLMUsage.empty_usage(), + finish_reason=finish_reason, + ) + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, LLMUsage | None], + start_log: AgentLog, + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[list[tuple[str, str, dict[str, Any]]], str, str | None], + ]: + """Handle LLM response chunks and extract tool calls and content. + + Returns a tuple of (tool_calls, response_content, finish_reason). 
+ """ + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + response_content: str = "" + finish_reason: str | None = None + if isinstance(chunks, Generator): + # Streaming response + for chunk in chunks: + # Extract tool calls + if self._has_tool_calls(chunk): + tool_calls.extend(self._extract_tool_calls(chunk)) + + # Extract content + if chunk.delta.message and chunk.delta.message.content: + response_content += self._extract_content(chunk.delta.message.content) + + # Track usage + if chunk.delta.usage: + self._accumulate_usage(llm_usage, chunk.delta.usage) + + # Capture finish reason + if chunk.delta.finish_reason: + finish_reason = chunk.delta.finish_reason + + yield chunk + else: + # Non-streaming response + result: LLMResult = chunks + + if self._has_tool_calls_result(result): + tool_calls.extend(self._extract_tool_calls_result(result)) + + if result.message and result.message.content: + response_content += self._extract_content(result.message.content) + + if result.usage: + self._accumulate_usage(llm_usage, result.usage) + + # Convert to streaming format + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), + ) + yield self._finish_log( + start_log, + data={ + "result": response_content, + }, + usage=llm_usage.get("usage"), + ) + return tool_calls, response_content, finish_reason + + def _create_assistant_message( + self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None + ) -> AssistantPromptMessage: + """Create assistant message with tool calls.""" + if tool_calls is None: + return AssistantPromptMessage(content=content) + return AssistantPromptMessage( + content=content or "", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=tc[0], + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])), + ) + for tc in tool_calls + ], + ) + + def _handle_tool_call( + self, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str, + messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]: + """Handle a single tool call and return response with files and meta.""" + # Find tool + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + raise ValueError(f"Tool {tool_name} not found") + + # Create tool call log + tool_call_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + ) + yield tool_call_log + + # Invoke tool using base class method with error handling + try: + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name) + + yield self._finish_log( + tool_call_log, + data={ + **tool_call_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + final_content = response_content or "Tool executed successfully" + # Add tool response to messages + messages.append( + ToolPromptMessage( + content=final_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return response_content, tool_files, tool_invoke_meta + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_call_log.status = 
AgentLog.LogStatus.ERROR + tool_call_log.error = error_message + tool_call_log.data = { + **tool_call_log.data, + "error": error_message, + } + yield tool_call_log + + # Add error message to conversation + error_content = f"Tool execution failed: {error_message}" + messages.append( + ToolPromptMessage( + content=error_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return error_content, [], None diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py new file mode 100644 index 0000000000..81aa7fe3b1 --- /dev/null +++ b/api/core/agent/patterns/react.py @@ -0,0 +1,415 @@ +"""ReAct strategy implementation.""" + +from __future__ import annotations + +import json +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Union + +from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.file import File +from core.model_manager import ModelInstance +from core.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + SystemPromptMessage, +) + +from .base import AgentPattern, ToolInvokeHook + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class ReActStrategy(AgentPattern): + """ReAct strategy using reasoning and acting approach.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ): + """Initialize the ReAct strategy with instruction support.""" + super().__init__( + model_instance=model_instance, + tools=tools, + context=context, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + files=files, + tool_invoke_hook=tool_invoke_hook, + ) + self.instruction = instruction + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the ReAct agent strategy.""" + # Initialize tracking + agent_scratchpad: list[AgentScratchpadUnit] = [] + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + react_state: bool = True + total_usage: dict[str, Any] = {"usage": None} + output_files: list[File] = [] # Track files produced by tools + final_text: str = "" + finish_reason: str | None = None + + # Add "Observation" to stop sequences + if "Observation" not in stop: + stop = stop.copy() + stop.append("Observation") + + while react_state and iteration_step <= max_iterations: + react_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + + # Build prompt with/without tools based on iteration + include_tools = iteration_step < max_iterations + current_messages = self._build_prompt_with_react_format( + prompt_messages, agent_scratchpad, include_tools, self.instruction + ) + + model_log = self._create_log( + label=f"{self.model_instance.model} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this 
round only + round_usage: dict[str, Any] = {"usage": None} + + # Use current messages directly (files are handled by base class if needed) + messages_to_use = current_messages + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=self.context.user_id or "", + callbacks=[], + ) + + # Process response + scratchpad, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log, current_messages + ) + agent_scratchpad.append(scratchpad) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Check if we have an action to execute + if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + react_state = True + # Execute tool + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + # Track files produced by tools + output_files.extend(tool_files) + + # Add observation to scratchpad for display + yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages) + else: + # Extract final answer + if scratchpad.action and scratchpad.action.action_input: + final_answer = scratchpad.action.action_input + if isinstance(final_answer, dict): + final_answer = json.dumps(final_answer, ensure_ascii=False) + final_text = str(final_answer) + elif scratchpad.thought: + # If no action but we have thought, use thought as final answer + final_text = scratchpad.thought + + yield self._finish_log( + round_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + "observation": scratchpad.observation or None, + "final_answer": final_text if not react_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason + ) + + def _build_prompt_with_react_format( + self, + original_messages: list[PromptMessage], + agent_scratchpad: list[AgentScratchpadUnit], + include_tools: bool = True, + instruction: str = "", + ) -> list[PromptMessage]: + """Build prompt messages with ReAct format.""" + # Copy messages to avoid modifying original + messages = list(original_messages) + + # Find and update the system prompt that should already exist + system_prompt_found = False + for i, msg in enumerate(messages): + if isinstance(msg, SystemPromptMessage): + system_prompt_found = True + # The system prompt from frontend already has the template, just replace placeholders + + # Format tools + tools_str = "" + tool_names = [] + if include_tools and self.tools: + # Convert tools to prompt message tools format + prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools] + tool_names = [tool.name for tool in prompt_tools] + + # Format tools as JSON for comprehensive information + from core.model_runtime.utils.encoders import jsonable_encoder + + tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2) + tool_names_str = ", ".join(f'"{name}"' for name in tool_names) + else: + tools_str = "No tools available" + tool_names_str 
= "" + + # Replace placeholders in the existing system prompt + updated_content = msg.content + assert isinstance(updated_content, str) + updated_content = updated_content.replace("{{instruction}}", instruction) + updated_content = updated_content.replace("{{tools}}", tools_str) + updated_content = updated_content.replace("{{tool_names}}", tool_names_str) + + # Create new SystemPromptMessage with updated content + messages[i] = SystemPromptMessage(content=updated_content) + break + + # If no system prompt found, that's unexpected but add scratchpad anyway + if not system_prompt_found: + # This shouldn't happen if frontend is working correctly + pass + + # Format agent scratchpad + scratchpad_str = "" + if agent_scratchpad: + scratchpad_parts: list[str] = [] + for unit in agent_scratchpad: + if unit.thought: + scratchpad_parts.append(f"Thought: {unit.thought}") + if unit.action_str: + scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```") + if unit.observation: + scratchpad_parts.append(f"Observation: {unit.observation}") + scratchpad_str = "\n".join(scratchpad_parts) + + # If there's a scratchpad, append it to the last message + if scratchpad_str: + messages.append(AssistantPromptMessage(content=scratchpad_str)) + + return messages + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, Any], + model_log: AgentLog, + current_messages: list[PromptMessage], + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[AgentScratchpadUnit, str | None], + ]: + """Handle LLM response chunks and extract action/thought. + + Returns a tuple of (scratchpad_unit, finish_reason). + """ + usage_dict: dict[str, Any] = {} + + # Convert non-streaming to streaming format if needed + if isinstance(chunks, LLMResult): + # Create a generator from the LLMResult + def result_to_chunks() -> Generator[LLMResultChunk, None, None]: + yield LLMResultChunk( + model=chunks.model, + prompt_messages=chunks.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=chunks.message, + usage=chunks.usage, + finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do + ), + system_fingerprint=chunks.system_fingerprint or "", + ) + + streaming_chunks = result_to_chunks() + else: + streaming_chunks = chunks + + react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict) + + # Initialize scratchpad unit + scratchpad = AgentScratchpadUnit( + agent_response="", + thought="", + action_str="", + observation="", + action=None, + ) + + finish_reason: str | None = None + + # Process chunks + for chunk in react_chunks: + if isinstance(chunk, AgentScratchpadUnit.Action): + # Action detected + action_str = json.dumps(chunk.model_dump()) + scratchpad.agent_response = (scratchpad.agent_response or "") + action_str + scratchpad.action_str = action_str + scratchpad.action = chunk + + yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) + else: + # Text chunk + chunk_text = str(chunk) + scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text + scratchpad.thought = (scratchpad.thought or "") + chunk_text + + yield self._create_text_chunk(chunk_text, current_messages) + + # Update usage + if usage_dict.get("usage"): + if llm_usage.get("usage"): + self._accumulate_usage(llm_usage, usage_dict["usage"]) + else: + llm_usage["usage"] = usage_dict["usage"] + + # Clean up thought + scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how 
to help you" + + # Finish model log + yield self._finish_log( + model_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + }, + usage=llm_usage.get("usage"), + ) + + return scratchpad, finish_reason + + def _handle_tool_call( + self, + action: AgentScratchpadUnit.Action, + prompt_messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File]]]: + """Handle tool call and return observation with files.""" + tool_name = action.action_name + tool_args: dict[str, Any] | str = action.action_input + + # Start tool log + tool_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + ) + yield tool_log + + # Find tool instance + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + # Finish tool log with error + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "error": f"Tool {tool_name} not found", + }, + ) + return f"Tool {tool_name} not found", [] + + # Ensure tool_args is a dict + tool_args_dict: dict[str, Any] + if isinstance(tool_args, str): + try: + tool_args_dict = json.loads(tool_args) + except json.JSONDecodeError: + tool_args_dict = {"input": tool_args} + elif not isinstance(tool_args, dict): + tool_args_dict = {"input": str(tool_args)} + else: + tool_args_dict = tool_args + + # Invoke tool using base class method with error handling + try: + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name) + + # Finish tool log + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + + return response_content or "Tool executed successfully", tool_files + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_log.status = AgentLog.LogStatus.ERROR + tool_log.error = error_message + tool_log.data = { + **tool_log.data, + "error": error_message, + } + yield tool_log + + return f"Tool execution failed: {error_message}", [] diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py new file mode 100644 index 0000000000..ad26075291 --- /dev/null +++ b/api/core/agent/patterns/strategy_factory.py @@ -0,0 +1,107 @@ +"""Strategy factory for creating agent strategies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.agent.entities import AgentEntity, ExecutionContext +from core.file.models import File +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelFeature + +from .base import AgentPattern, ToolInvokeHook +from .function_call import FunctionCallStrategy +from .react import ReActStrategy + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class StrategyFactory: + """Factory for creating agent strategies based on model features.""" + + # Tool calling related features + TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL} + + @staticmethod + def create_strategy( + model_features: list[ModelFeature], + model_instance: ModelInstance, + context: ExecutionContext, + tools: list[Tool], + files: list[File], + max_iterations: int = 10, + workflow_call_depth: int = 
0, + agent_strategy: AgentEntity.Strategy | None = None, + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ) -> AgentPattern: + """ + Create an appropriate strategy based on model features. + + Args: + model_features: List of model features/capabilities + model_instance: Model instance to use + context: Execution context containing trace/audit information + tools: Available tools + files: Available files + max_iterations: Maximum iterations for the strategy + workflow_call_depth: Depth of workflow calls + agent_strategy: Optional explicit strategy override + tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke) + instruction: Optional instruction for ReAct strategy + + Returns: + AgentStrategy instance + """ + # If explicit strategy is provided and it's Function Calling, try to use it if supported + if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING: + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + # Fallback to ReAct if FC is requested but not supported + + # If explicit strategy is Chain of Thought (ReAct) + if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Default auto-selection logic + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + # Model supports native function calling + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + else: + # Use ReAct strategy for models without function calling + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index da1e9f19b6..53fa27cca7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,7 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass, field from threading import Thread from typing import Any, Union @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + ChunkType, MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueAgentLogEvent, @@ -70,13 +72,120 @@ from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models import Account, Conversation, EndUser, Message, MessageFile +from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile from models.enums import CreatorUserRole from models.workflow import Workflow logger = 
logging.getLogger(__name__) +@dataclass +class StreamEventBuffer: + """ + Buffer for recording stream events in order to reconstruct the generation sequence. + Records the exact order of text chunks, thoughts, and tool calls as they stream. + """ + + # Accumulated reasoning content (each thought block is a separate element) + reasoning_content: list[str] = field(default_factory=list) + # Current reasoning buffer (accumulates until we see a different event type) + _current_reasoning: str = "" + # Tool calls with their details + tool_calls: list[dict] = field(default_factory=list) + # Tool call ID to index mapping for updating results + _tool_call_id_map: dict[str, int] = field(default_factory=dict) + # Sequence of events in stream order + sequence: list[dict] = field(default_factory=list) + # Current position in answer text + _content_position: int = 0 + # Track last event type to detect transitions + _last_event_type: str | None = None + + def _flush_current_reasoning(self) -> None: + """Flush accumulated reasoning to the list and add to sequence.""" + if self._current_reasoning.strip(): + self.reasoning_content.append(self._current_reasoning.strip()) + self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1}) + self._current_reasoning = "" + + def record_text_chunk(self, text: str) -> None: + """Record a text chunk event.""" + if not text: + return + + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + text_len = len(text) + start_pos = self._content_position + + # If last event was also content, extend it; otherwise create new + if self.sequence and self.sequence[-1].get("type") == "content": + self.sequence[-1]["end"] = start_pos + text_len + else: + self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len}) + + self._content_position += text_len + self._last_event_type = "content" + + def record_thought_chunk(self, text: str) -> None: + """Record a thought/reasoning chunk event.""" + if not text: + return + + # Accumulate thought content + self._current_reasoning += text + self._last_event_type = "thought" + + def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None: + """Record a tool call event.""" + if not tool_call_id: + return + + # Flush any pending reasoning first + if self._last_event_type == "thought": + self._flush_current_reasoning() + + # Check if this tool call already exists (we might get multiple chunks) + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + # Update arguments if provided + if tool_arguments: + self.tool_calls[idx]["arguments"] = tool_arguments + else: + # New tool call + tool_call = { + "id": tool_call_id or "", + "name": tool_name or "", + "arguments": tool_arguments or "", + "result": "", + } + self.tool_calls.append(tool_call) + idx = len(self.tool_calls) - 1 + self._tool_call_id_map[tool_call_id] = idx + self.sequence.append({"type": "tool_call", "index": idx}) + + self._last_event_type = "tool_call" + + def record_tool_result(self, tool_call_id: str, result: str) -> None: + """Record a tool result event (update existing tool call).""" + if not tool_call_id: + return + if tool_call_id in self._tool_call_id_map: + idx = self._tool_call_id_map[tool_call_id] + self.tool_calls[idx]["result"] = result + + def finalize(self) -> None: + """Finalize the buffer, flushing any pending data.""" + if self._last_event_type == "thought": + self._flush_current_reasoning() + 
+ def has_data(self) -> bool: + """Check if there's any meaningful data recorded.""" + return bool(self.reasoning_content or self.tool_calls or self.sequence) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -144,6 +253,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id: str = "" self._draft_var_saver_factory = draft_var_saver_factory self._graph_runtime_state: GraphRuntimeState | None = None + # Stream event buffer for recording generation sequence + self._stream_buffer = StreamEventBuffer() self._seed_graph_runtime_state_from_queue_manager() def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -383,7 +494,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None, **kwargs, ) -> Generator[StreamResponse, None, None]: - """Handle text chunk events.""" + """Handle text chunk events and record to stream buffer for sequence reconstruction.""" delta_text = event.text if delta_text is None: return @@ -405,9 +516,45 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if tts_publisher and queue_message: tts_publisher.publish(queue_message) - self._task_state.answer += delta_text + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else "" + tool_name = tool_payload.name if tool_payload and tool_payload.name else "" + tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else "" + tool_files = tool_result.files if tool_result else [] + + # Record stream event based on chunk type + chunk_type = event.chunk_type or ChunkType.TEXT + match chunk_type: + case ChunkType.TEXT: + self._stream_buffer.record_text_chunk(delta_text) + self._task_state.answer += delta_text + case ChunkType.THOUGHT: + # Reasoning should not be part of final answer text + self._stream_buffer.record_thought_chunk(delta_text) + case ChunkType.TOOL_CALL: + self._stream_buffer.record_tool_call( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + ) + case ChunkType.TOOL_RESULT: + self._stream_buffer.record_tool_result( + tool_call_id=tool_call_id, + result=delta_text, + ) + self._task_state.answer += delta_text + yield self._message_cycle_manager.message_to_stream_response( - answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector + answer=delta_text, + message_id=self._message_id, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type.value if event.chunk_type else None, + tool_call_id=tool_call_id or None, + tool_name=tool_name or None, + tool_arguments=tool_arguments or None, + tool_files=tool_files, ) def _handle_iteration_start_event( @@ -775,6 +922,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): # If there are assistant files, remove markdown image links from answer answer_text = self._task_state.answer + answer_text = self._strip_think_blocks(answer_text) if self._recorded_files: # Remove markdown image links since we're storing files separately answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip() @@ -826,6 +974,54 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ] 
 session.add_all(message_files) + # Save generation detail (reasoning/tool calls/sequence) from stream buffer + self._save_generation_detail(session=session, message=message) + + @staticmethod + def _strip_think_blocks(text: str) -> str: + """Remove <think>...</think> blocks (including their content) from text.""" + if not text or "<think" not in text.lower(): + return text + clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL) + clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() + return clean_text + + def _save_generation_detail(self, *, session: Session, message: Message) -> None: + """ + Save LLM generation detail for Chatflow using stream event buffer. + The buffer records the exact order of events as they streamed, + allowing accurate reconstruction of the generation sequence. + """ + # Finalize the stream buffer to flush any pending data + self._stream_buffer.finalize() + + # Only save if there's meaningful data + if not self._stream_buffer.has_data(): + return + + reasoning_content = self._stream_buffer.reasoning_content + tool_calls = self._stream_buffer.tool_calls + sequence = self._stream_buffer.sequence + + # Check if generation detail already exists for this message + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None + existing.tool_calls = json.dumps(tool_calls) if tool_calls else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_content) if reasoning_content else None, + tool_calls=json.dumps(tool_calls) if tool_calls else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + def _seed_graph_runtime_state_from_queue_manager(self) -> None: """Bootstrap the cached runtime state from the queue manager when present.""" candidate = self._base_task_pipeline.queue_manager.graph_runtime_state diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 2760466a3b..f5cf7a2c56 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -3,10 +3,8 @@ from typing import cast from sqlalchemy import select -from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from core.agent.agent_app_runner import AgentAppRunner from core.agent.entities import AgentEntity -from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner @@ -14,8 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from 
core.moderation.base import ModerationError from extensions.ext_database import db @@ -194,22 +191,7 @@ class AgentChatAppRunner(AppRunner): raise ValueError("Message not found") db.session.close() - runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] - # start agent runner - if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: - # check LLM mode - if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: - runner_cls = CotChatAgentRunner - elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION: - runner_cls = CotCompletionAgentRunner - else: - raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}") - elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: - runner_cls = FunctionCallAgentRunner - else: - raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}") - - runner = runner_cls( + runner = AgentAppRunner( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, conversation=conversation_result, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..0f3f9972c3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -671,7 +671,7 @@ class WorkflowResponseConverter: task_id=task_id, data=AgentLogStreamResponse.Data( node_execution_id=event.node_execution_id, - id=event.id, + message_id=event.id, parent_id=event.parent_id, label=event.label, error=event.error, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 842ad545ad..1a9d09f5e7 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -13,6 +13,7 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, + ChunkType, MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, @@ -483,11 +484,27 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): if delta_text is None: return + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None + tool_name = tool_payload.name if tool_payload and tool_payload.name else None + tool_arguments = tool_call.arguments if tool_call else None + tool_files = tool_result.files if tool_result else [] + # only publish tts message at text chunk streaming if tts_publisher and queue_message: tts_publisher.publish(queue_message) - yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) + yield self._text_chunk_to_stream_response( + text=delta_text, + from_variable_selector=event.from_variable_selector, + chunk_type=event.chunk_type, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files, + ) def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" @@ -650,16 +667,35 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): session.add(workflow_app_log) def _text_chunk_to_stream_response( - 
self, text: str, from_variable_selector: list[str] | None = None + self, + text: str, + from_variable_selector: list[str] | None = None, + chunk_type: ChunkType | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, ) -> TextChunkStreamResponse: """ Handle completed event. :param text: text :return: """ + from core.app.entities.task_entities import ChunkType as ResponseChunkType + response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), + data=TextChunkStreamResponse.Data( + text=text, + from_variable_selector=from_variable_selector, + chunk_type=ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files or [], + tool_error=tool_error, + ), ) return response diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 0e125b3538..3b02683764 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -455,12 +455,20 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunStreamChunkEvent): + from core.app.entities.queue_entities import ChunkType as QueueChunkType + + if event.is_final and not event.chunk: + return + self._publish_event( QueueTextChunkEvent( text=event.chunk, from_variable_selector=list(event.selector), in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, + chunk_type=QueueChunkType(event.chunk_type.value), + tool_call=event.tool_call, + tool_result=event.tool_result, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): diff --git a/api/core/app/entities/llm_generation_entities.py b/api/core/app/entities/llm_generation_entities.py new file mode 100644 index 0000000000..4e278249fe --- /dev/null +++ b/api/core/app/entities/llm_generation_entities.py @@ -0,0 +1,69 @@ +""" +LLM Generation Detail entities. + +Defines the structure for storing and transmitting LLM generation details +including reasoning content, tool calls, and their sequence. 
+""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class ContentSegment(BaseModel): + """Represents a content segment in the generation sequence.""" + + type: Literal["content"] = "content" + start: int = Field(..., description="Start position in the text") + end: int = Field(..., description="End position in the text") + + +class ReasoningSegment(BaseModel): + """Represents a reasoning segment in the generation sequence.""" + + type: Literal["reasoning"] = "reasoning" + index: int = Field(..., description="Index into reasoning_content array") + + +class ToolCallSegment(BaseModel): + """Represents a tool call segment in the generation sequence.""" + + type: Literal["tool_call"] = "tool_call" + index: int = Field(..., description="Index into tool_calls array") + + +SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment + + +class ToolCallDetail(BaseModel): + """Represents a tool call with its arguments and result.""" + + id: str = Field(default="", description="Unique identifier for the tool call") + name: str = Field(..., description="Name of the tool") + arguments: str = Field(default="", description="JSON string of tool arguments") + result: str = Field(default="", description="Result from the tool execution") + + +class LLMGenerationDetailData(BaseModel): + """ + Domain model for LLM generation detail. + + Contains the structured data for reasoning content, tool calls, + and their display sequence. + """ + + reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments") + tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details") + sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments") + + def is_empty(self) -> bool: + """Check if there's any meaningful generation detail.""" + return not self.reasoning_content and not self.tool_calls + + def to_response_dict(self) -> dict: + """Convert to dictionary for API response.""" + return { + "reasoning_content": self.reasoning_content, + "tool_calls": [tc.model_dump() for tc in self.tool_calls], + "sequence": [seg.model_dump() for seg in self.sequence], + } diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 77d6bf03b4..31b95ad165 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult from core.workflow.enums import WorkflowNodeExecutionMetadataKey from core.workflow.nodes import NodeType @@ -177,6 +177,15 @@ class QueueLoopCompletedEvent(AppQueueEvent): error: str | None = None +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity @@ -191,6 +200,16 @@ class QueueTextChunkEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = 
ChunkType.TEXT + """type of the chunk""" + + # Tool streaming payloads + tool_call: ToolCall | None = None + """structured tool call info""" + tool_result: ToolResult | None = None + """structured tool result info""" + class QueueAgentMessageEvent(AppQueueEvent): """ diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..28951021b6 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -113,6 +113,24 @@ class MessageStreamResponse(StreamResponse): answer: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import) + chunk_type: str | None = None + """type of the chunk: text, tool_call, tool_result, thought""" + + # Tool call fields (when chunk_type == "tool_call") + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == "tool_result") + tool_files: list[str] | None = None + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + class MessageAudioStreamResponse(StreamResponse): """ @@ -582,6 +600,15 @@ class LoopNodeCompletedStreamResponse(StreamResponse): data: Data +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class TextChunkStreamResponse(StreamResponse): """ TextChunkStreamResponse entity @@ -595,6 +622,24 @@ class TextChunkStreamResponse(StreamResponse): text: str from_variable_selector: list[str] | None = None + # Extended fields for Agent/Tool streaming + chunk_type: ChunkType = ChunkType.TEXT + """type of the chunk""" + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call_id: str | None = None + """unique identifier for this tool call""" + tool_name: str | None = None + """name of the tool being called""" + tool_arguments: str | None = None + """accumulated tool arguments JSON""" + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_files: list[str] = Field(default_factory=list) + """file IDs produced by tool""" + tool_error: str | None = None + """error message if tool failed""" + event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -743,7 +788,7 @@ class AgentLogStreamResponse(StreamResponse): """ node_execution_id: str - id: str + message_id: str label: str parent_id: str | None = None error: str | None = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 5bb93fa44a..a1852ffe19 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,5 @@ import logging +import re import time from collections.abc import Generator from threading import Thread @@ -58,7 +59,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.model import AppMode, Conversation, Message, MessageAgentThought +from models.model 
import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL) + _task_state: EasyUITaskState _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] @@ -409,11 +412,136 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) ) + # Save LLM generation detail if there's reasoning_content + self._save_generation_detail(session=session, message=message, llm_result=llm_result) + message_was_created.send( message, application_generate_entity=self._application_generate_entity, ) + def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None: + """ + Save LLM generation detail for Completion/Chat/Agent-Chat applications. + For Agent-Chat, also merges MessageAgentThought records. + """ + import json + + reasoning_list: list[str] = [] + tool_calls_list: list[dict] = [] + sequence: list[dict] = [] + answer = message.answer or "" + + # Check if this is Agent-Chat mode by looking for agent thoughts + agent_thoughts = ( + session.query(MessageAgentThought) + .filter_by(message_id=message.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) + + if agent_thoughts: + # Agent-Chat mode: merge MessageAgentThought records + content_pos = 0 + cleaned_answer_parts: list[str] = [] + for thought in agent_thoughts: + # Add thought/reasoning + if thought.thought: + reasoning_text = thought.thought + else: + # Chat/Completion mode: extract <think> blocks and clean the final answer + clean_answer, reasoning_content = self._split_reasoning_from_answer(answer) + if reasoning_content: + answer = clean_answer + llm_result.message.content = clean_answer + llm_result.reasoning_content = reasoning_content + message.answer = clean_answer + if reasoning_content: + reasoning_list = [reasoning_content] + # Content comes first, then reasoning + if answer: + sequence.append({"type": "content", "start": 0, "end": len(answer)}) + sequence.append({"type": "reasoning", "index": 0}) + + # Only save if there's meaningful generation detail + if not reasoning_list and not tool_calls_list: + return + + # Check if generation detail already exists + existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first() + + if existing: + existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None + existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None + existing.sequence = json.dumps(sequence) if sequence else None + else: + generation_detail = LLMGenerationDetail( + tenant_id=self._application_generate_entity.app_config.tenant_id, + app_id=self._application_generate_entity.app_config.app_id, + message_id=message.id, + reasoning_content=json.dumps(reasoning_list) if reasoning_list else None, + tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None, + sequence=json.dumps(sequence) if sequence else None, + ) + session.add(generation_detail) + + @classmethod + def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]: + """ + Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
+ """ + matches = cls._THINK_PATTERN.findall(text) + reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" + + clean_text = cls._THINK_PATTERN.sub("", text) + clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() + + return clean_text, reasoning_content or "" + def _handle_stop(self, event: QueueStopEvent): """ Handle stop. diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 0e7f300cee..f2fa3c064b 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -232,12 +232,25 @@ class MessageCycleManager: answer: str, message_id: str, from_variable_selector: list[str] | None = None, + chunk_type: str | None = None, + tool_call_id: str | None = None, + tool_name: str | None = None, + tool_arguments: str | None = None, + tool_files: list[str] | None = None, + tool_error: str | None = None, event_type: StreamEvent | None = None, ) -> MessageStreamResponse: """ Message to stream response. :param answer: answer :param message_id: message id + :param from_variable_selector: from variable selector + :param chunk_type: type of the chunk (text, function_call, tool_result, thought) + :param tool_call_id: unique identifier for this tool call + :param tool_name: name of the tool being called + :param tool_arguments: accumulated tool arguments JSON + :param tool_files: file IDs produced by tool + :param tool_error: error message if tool failed :return: """ return MessageStreamResponse( @@ -245,6 +258,12 @@ class MessageCycleManager: id=message_id, answer=answer, from_variable_selector=from_variable_selector, + chunk_type=chunk_type, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files, + tool_error=tool_error, event=event_type or StreamEvent.MESSAGE, ) diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index d0279349ca..5249fea8cd 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -5,7 +5,6 @@ from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document @@ -90,6 +89,8 @@ class DatasetIndexToolCallbackHandler: # TODO(-LAN-): Improve type check def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): """Handle return_retriever_resource_info.""" + from core.app.entities.queue_entities import QueueRetrieverResourcesEvent + self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4436773d25..a45d1d1046 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -29,6 +29,7 @@ from models import ( Account, CreatorUserRole, EndUser, + LLMGenerationDetail, WorkflowNodeExecutionModel, 
WorkflowNodeExecutionTriggeredFrom, ) @@ -457,6 +458,113 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) session.merge(db_model) session.flush() + # Save LLMGenerationDetail for LLM nodes with successful execution + if ( + domain_model.node_type == NodeType.LLM + and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED + and domain_model.outputs is not None + ): + self._save_llm_generation_detail(session, domain_model) + + def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None: + """ + Save LLM generation detail for LLM nodes. + Extracts reasoning_content, tool_calls, and sequence from outputs and metadata. + """ + outputs = execution.outputs or {} + metadata = execution.metadata or {} + + reasoning_list = self._extract_reasoning(outputs) + tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)) + + if not reasoning_list and not tool_calls_list: + return + + sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list) + self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence) + + def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]: + """Extract reasoning_content as a clean list of non-empty strings.""" + reasoning_content = outputs.get("reasoning_content") + if isinstance(reasoning_content, str): + trimmed = reasoning_content.strip() + return [trimmed] if trimmed else [] + if isinstance(reasoning_content, list): + return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()] + return [] + + def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]: + """Extract tool call records from agent logs.""" + if not agent_log or not isinstance(agent_log, list): + return [] + + tool_calls: list[dict[str, str]] = [] + for log in agent_log: + log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {}) + tool_name = log_data.get("tool_name") + if tool_name and str(tool_name).strip(): + tool_calls.append( + { + "id": log_data.get("tool_call_id", ""), + "name": tool_name, + "arguments": json.dumps(log_data.get("tool_args", {})), + "result": str(log_data.get("output", "")), + } + ) + return tool_calls + + def _build_generation_sequence( + self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]] + ) -> list[dict[str, Any]]: + """Build a simple content/reasoning/tool_call sequence.""" + sequence: list[dict[str, Any]] = [] + if text: + sequence.append({"type": "content", "start": 0, "end": len(text)}) + for index in range(len(reasoning_list)): + sequence.append({"type": "reasoning", "index": index}) + for index in range(len(tool_calls_list)): + sequence.append({"type": "tool_call", "index": index}) + return sequence + + def _upsert_generation_detail( + self, + session, + execution: WorkflowNodeExecution, + reasoning_list: list[str], + tool_calls_list: list[dict[str, str]], + sequence: list[dict[str, Any]], + ) -> None: + """Insert or update LLMGenerationDetail with serialized fields.""" + existing = ( + session.query(LLMGenerationDetail) + .filter_by( + workflow_run_id=execution.workflow_execution_id, + node_id=execution.node_id, + ) + .first() + ) + + reasoning_json = json.dumps(reasoning_list) if reasoning_list else None + tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None + sequence_json = json.dumps(sequence) if sequence else None + + if existing: + 
existing.reasoning_content = reasoning_json + existing.tool_calls = tool_calls_json + existing.sequence = sequence_json + return + + generation_detail = LLMGenerationDetail( + tenant_id=self._tenant_id, + app_id=self._app_id, + workflow_run_id=execution.workflow_execution_id, + node_id=execution.node_id, + reasoning_content=reasoning_json, + tool_calls=tool_calls_json, + sequence=sequence_json, + ) + session.add(generation_detail) + def get_db_models_by_workflow_run( self, workflow_run_id: str, diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 8ca4eabb7a..cdbfd027ee 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from models.model import File +from core.model_runtime.entities.message_entities import PromptMessageTool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, @@ -152,6 +153,60 @@ class Tool(ABC): return parameters + def to_prompt_message_tool(self) -> PromptMessageTool: + message_tool = PromptMessageTool( + name=self.entity.identity.name, + description=self.entity.description.llm if self.entity.description else "", + parameters={ + "type": "object", + "properties": {}, + "required": [], + }, + ) + + parameters = self.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + # Determine the description based on parameter type + if parameter.type == ToolParameter.ToolParameterType.FILE: + file_format_desc = " Input the file id with format: [File: file_id]." + else: + file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. 
" + + message_tool.parameters["properties"][parameter.name] = { + "type": "string", + "description": (parameter.llm_description or "") + file_format_desc, + } + continue + enum = [] + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] if parameter.options else [] + + message_tool.parameters["properties"][parameter.name] = ( + { + "type": parameter_type, + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else parameter.input_schema + ) + + if len(enum) > 0: + message_tool.parameters["properties"][parameter.name]["enum"] = enum + + if parameter.required: + message_tool.parameters["required"].append(parameter.name) + + return message_tool + def create_image_message( self, image: str, diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..0f3b9a5239 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,11 +1,16 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams +from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", + "ToolCall", + "ToolCallResult", + "ToolResult", + "ToolResultStatus", "WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/entities/tool_entities.py b/api/core/workflow/entities/tool_entities.py new file mode 100644 index 0000000000..f4833218c7 --- /dev/null +++ b/api/core/workflow/entities/tool_entities.py @@ -0,0 +1,33 @@ +from enum import StrEnum + +from pydantic import BaseModel, Field + +from core.file import File + + +class ToolResultStatus(StrEnum): + SUCCESS = "success" + ERROR = "error" + + +class ToolCall(BaseModel): + id: str | None = Field(default=None, description="Unique identifier for this tool call") + name: str | None = Field(default=None, description="Name of the tool being called") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + + +class ToolResult(BaseModel): + id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to") + name: str | None = Field(default=None, description="Name of the tool") + output: str | None = Field(default=None, description="Tool output text, error or success message") + files: list[str] = Field(default_factory=list, description="File produced by tool") + status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") + + +class ToolCallResult(BaseModel): + id: str | None = Field(default=None, description="Identifier for the tool call") + name: str | None = Field(default=None, description="Name of the tool") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + output: str | None = Field(default=None, description="Tool output text, error or success message") + files: list[File] = Field(default_factory=list, description="File produced by tool") + status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index c08b62a253..5ea7cf1a07 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -247,6 +247,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): ERROR_STRATEGY = 
"error_strategy" # node in continue on error mode return the field LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output DATASOURCE_INFO = "datasource_info" + LLM_CONTENT_SEQUENCE = "llm_content_sequence" + LLM_TRACE = "llm_trace" COMPLETED_REASON = "completed_reason" # completed reason for loop node diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 98e0ea91ef..c5ea94ba80 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -16,7 +16,13 @@ from pydantic import BaseModel, Field from core.workflow.enums import NodeExecutionType, NodeState from core.workflow.graph import Graph -from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent +from core.workflow.graph_events import ( + ChunkType, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ToolCall, + ToolResult, +) from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.runtime import VariablePool @@ -321,11 +327,24 @@ class ResponseStreamCoordinator: selector: Sequence[str], chunk: str, is_final: bool = False, + chunk_type: ChunkType = ChunkType.TEXT, + tool_call: ToolCall | None = None, + tool_result: ToolResult | None = None, ) -> NodeRunStreamChunkEvent: """Create a stream chunk event with consistent structure. For selectors with special prefixes (sys, env, conversation), we use the active response node's information since these are not actual node IDs. + + Args: + node_id: The node ID to attribute the event to + execution_id: The execution ID for this node + selector: The variable selector + chunk: The chunk content + is_final: Whether this is the final chunk + chunk_type: The semantic type of the chunk being streamed + tool_call: Structured data for tool_call chunks + tool_result: Structured data for tool_result chunks """ # Check if this is a special selector that doesn't correspond to a node if selector and selector[0] not in self._graph.nodes and self._active_session: @@ -338,6 +357,9 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, + chunk_type=chunk_type, + tool_call=tool_call, + tool_result=tool_result, ) # Standard case: selector refers to an actual node @@ -349,6 +371,9 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, + chunk_type=chunk_type, + tool_call=tool_call, + tool_result=tool_result, ) def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: @@ -356,6 +381,8 @@ class ResponseStreamCoordinator: Handles both regular node selectors and special system selectors (sys, env, conversation). For special selectors, we attribute the output to the active response node. + + For object-type variables, automatically streams all child fields that have stream events. 
""" events: list[NodeRunStreamChunkEvent] = [] source_selector_prefix = segment.selector[0] if segment.selector else "" @@ -364,60 +391,81 @@ class ResponseStreamCoordinator: # Determine which node to attribute the output to # For special selectors (sys, env, conversation), use the active response node # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix + active_session = self._active_session + special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes) + output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix execution_id = self._get_or_create_execution_id(output_node_id) - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) + # Check if there's a direct stream for this selector + has_direct_stream = ( + tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams + ) - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True + stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector)) - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, + if stream_targets: + all_complete = True + + for target_selector in stream_targets: + while self._has_unread_stream(target_selector): + if event := self._pop_stream_chunk(target_selector): + events.append( + self._rewrite_stream_event( + event=event, + output_node_id=output_node_id, + execution_id=execution_id, + special_selector=bool(special_selector), + ) + ) + + if not self._is_stream_closed(target_selector): + all_complete = False + + is_complete = all_complete + + # Fallback: check if scalar value exists in variable pool + if not is_complete and not has_direct_stream: + if value := self._variable_pool.get(segment.selector): + # Process scalar value + is_last_segment = bool( + self._active_session + and self._active_session.index == len(self._active_session.template.segments) - 1 ) - ) - is_complete = True + events.append( + self._create_stream_chunk_event( + node_id=output_node_id, 
+ execution_id=execution_id, + selector=segment.selector, + chunk=value.markdown, + is_final=is_last_segment, + ) + ) + is_complete = True return events, is_complete + def _rewrite_stream_event( + self, + event: NodeRunStreamChunkEvent, + output_node_id: str, + execution_id: str, + special_selector: bool, + ) -> NodeRunStreamChunkEvent: + """Rewrite event to attribute to active response node when selector is special.""" + if not special_selector: + return event + + return self._create_stream_chunk_event( + node_id=output_node_id, + execution_id=execution_id, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=event.chunk_type, + tool_call=event.tool_call, + tool_result=event.tool_result, + ) + def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: """Process a text segment. Returns (events, is_complete).""" assert self._active_session is not None @@ -513,6 +561,36 @@ class ResponseStreamCoordinator: # ============= Internal Stream Management Methods ============= + def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]: + """Find all child stream selectors that are descendants of the parent selector. + + For example, if parent_selector is ['llm', 'generation'], this will find: + - ['llm', 'generation', 'content'] + - ['llm', 'generation', 'tool_calls'] + - ['llm', 'generation', 'tool_results'] + - ['llm', 'generation', 'thought'] + + Args: + parent_selector: The parent selector to search for children + + Returns: + List of child selector tuples found in stream buffers or closed streams + """ + parent_key = tuple(parent_selector) + parent_len = len(parent_key) + child_streams: set[tuple[str, ...]] = set() + + # Search in both active buffers and closed streams + all_selectors = set(self._stream_buffers.keys()) | self._closed_streams + + for selector_key in all_selectors: + # Check if this selector is a direct child of the parent + # Direct child means: len(child) == len(parent) + 1 and child starts with parent + if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key: + child_streams.add(selector_key) + + return sorted(child_streams) + def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: """ Append a stream chunk to the internal buffer. 
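A note on the child-stream lookup above: `_find_child_streams` lets a template that references a parent selector such as ['llm', 'generation'] pick up every streamed sub-field (content, thought, tool calls) without listing them explicitly. A minimal standalone sketch of that prefix match follows; the helper name and sample selectors are illustrative only and not part of this patch.

def find_child_streams(stream_keys: set[tuple[str, ...]], parent: tuple[str, ...]) -> list[tuple[str, ...]]:
    # A direct child is exactly one segment longer than the parent and shares its prefix.
    return sorted(k for k in stream_keys if len(k) == len(parent) + 1 and k[: len(parent)] == parent)

# Example: buffers keyed by selector tuples, as in _stream_buffers / _closed_streams.
keys = {("llm", "generation", "content"), ("llm", "generation", "thought"), ("llm", "text")}
assert find_child_streams(keys, ("llm", "generation")) == [
    ("llm", "generation", "content"),
    ("llm", "generation", "thought"),
]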
diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 7a5edbb331..4ee0ec94d2 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -36,6 +36,7 @@ from .loop import ( # Node events from .node import ( + ChunkType, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -44,10 +45,13 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + ToolCall, + ToolResult, ) __all__ = [ "BaseGraphEvent", + "ChunkType", "GraphEngineEvent", "GraphNodeEventBase", "GraphRunAbortedEvent", @@ -73,4 +77,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "ToolCall", + "ToolResult", ] diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index f225798d41..01bc27d3e4 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -21,13 +22,37 @@ class NodeRunStartedEvent(GraphNodeEventBase): provider_id: str = "" +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields + """Stream chunk event for workflow node execution.""" + + # Base fields selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + + # Tool call fields (when chunk_type == TOOL_CALL) + tool_call: ToolCall | None = Field( + default=None, + description="structured payload for tool_call chunks", + ) + + # Tool result fields (when chunk_type == TOOL_RESULT) + tool_result: ToolResult | None = Field( + default=None, + description="structured payload for tool_result chunks", + ) class NodeRunRetrieverResourceEvent(GraphNodeEventBase): diff --git a/api/core/workflow/node_events/__init__.py b/api/core/workflow/node_events/__init__.py index f14a594c85..67263311b9 100644 --- a/api/core/workflow/node_events/__init__.py +++ b/api/core/workflow/node_events/__init__.py @@ -13,16 +13,21 @@ from .loop import ( LoopSucceededEvent, ) from .node import ( + ChunkType, ModelInvokeCompletedEvent, PauseRequestedEvent, RunRetrieverResourceEvent, RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) __all__ = [ "AgentLogEvent", + "ChunkType", "IterationFailedEvent", "IterationNextEvent", "IterationStartedEvent", @@ -39,4 +44,7 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "ThoughtChunkEvent", + "ToolCallChunkEvent", + "ToolResultChunkEvent", ] diff --git 
a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index e4fa52f444..39f09d02a5 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -1,11 +1,13 @@ from collections.abc import Sequence from datetime import datetime +from enum import StrEnum from pydantic import Field from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.entities import ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from core.workflow.node_events import NodeRunResult @@ -32,13 +34,46 @@ class RunRetryEvent(NodeEventBase): start_at: datetime = Field(..., description="Retry start time") +class ChunkType(StrEnum): + """Stream chunk type for LLM-related events.""" + + TEXT = "text" # Normal text streaming + TOOL_CALL = "tool_call" # Tool call arguments streaming + TOOL_RESULT = "tool_result" # Tool execution result + THOUGHT = "thought" # Agent thinking process (ReAct) + + class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields + """Base stream chunk event - normal text streaming output.""" + selector: Sequence[str] = Field( ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" ) chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") + chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks") + tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks") + + +class ToolCallChunkEvent(StreamChunkEvent): + """Tool call streaming event - tool call arguments streaming output.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True) + tool_call: ToolCall | None = Field(default=None, description="structured tool call payload") + + +class ToolResultChunkEvent(StreamChunkEvent): + """Tool result event - tool execution result.""" + + chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True) + tool_result: ToolResult | None = Field(default=None, description="structured tool result payload") + + +class ThoughtChunkEvent(StreamChunkEvent): + """Agent thought streaming event - Agent thinking process (ReAct).""" + + chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True) class StreamCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 8ebba3659c..302d77d625 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -46,6 +46,9 @@ from core.workflow.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) from core.workflow.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now @@ -543,6 +546,8 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + return NodeRunStreamChunkEvent( id=self.execution_id, node_id=self._node_id, @@ -550,6 +555,65 @@ class Node(Generic[NodeDataT]): selector=event.selector, chunk=event.chunk, is_final=event.is_final, + chunk_type=ChunkType(event.chunk_type.value), + 
tool_call=event.tool_call, + tool_result=event.tool_result, + ) + + @_dispatch.register + def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_CALL, + tool_call=event.tool_call, + ) + + @_dispatch.register + def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.entities import ToolResult, ToolResultStatus + from core.workflow.graph_events import ChunkType + + tool_result = event.tool_result + status: ToolResultStatus = ( + tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS + ) + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.TOOL_RESULT, + tool_result=ToolResult( + id=tool_result.id if tool_result else None, + name=tool_result.name if tool_result else None, + output=tool_result.output if tool_result else None, + files=tool_result.files if tool_result else [], + status=status, + ), + ) + + @_dispatch.register + def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent: + from core.workflow.graph_events import ChunkType + + return NodeRunStreamChunkEvent( + id=self._node_execution_id, + node_id=self._node_id, + node_type=self.node_type, + selector=event.selector, + chunk=event.chunk, + is_final=event.is_final, + chunk_type=ChunkType.THOUGHT, ) @_dispatch.register diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index f7bc713f63..edd0d3d581 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -3,6 +3,7 @@ from .entities import ( LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + ToolMetadata, VisionConfig, ) from .node import LLMNode @@ -13,5 +14,6 @@ __all__ = [ "LLMNodeCompletionModelPromptTemplate", "LLMNodeData", "ModelConfig", + "ToolMetadata", "VisionConfig", ] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index fe6f2290aa..c1938fb5e3 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,10 +1,17 @@ +import re from collections.abc import Mapping, Sequence from typing import Any, Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator +from core.agent.entities import AgentLog, AgentResult +from core.file import File from core.model_runtime.entities import ImagePromptMessageContent, LLMMode +from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.entities import ToolCallResult +from core.workflow.node_events import AgentLogEvent from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector @@ -58,6 +65,235 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): jinja2_text: str | None = None +class ToolMetadata(BaseModel): + """ + Tool metadata for LLM node with tool support. 
+ + Defines the essential fields needed for tool configuration, + particularly the 'type' field to identify tool provider type. + """ + + # Core fields + enabled: bool = True + type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow") + provider_name: str = Field(..., description="Tool provider name/identifier") + tool_name: str = Field(..., description="Tool name") + + # Optional fields + plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools") + credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication") + + # Configuration fields + parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters") + settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration") + extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description") + + +class LLMTraceSegment(BaseModel): + """ + Streaming trace segment for LLM tool-enabled runs. + + Order is preserved for replay. Tool calls are single entries containing both + arguments and results. + """ + + type: Literal["thought", "content", "tool_call"] + + # Common optional fields + text: str | None = Field(None, description="Text chunk for thought/content") + + # Tool call fields (combined start + result) + tool_call: ToolCallResult | None = Field( + default=None, + description="Combined tool call arguments and result for this segment", + ) + + +class LLMGenerationData(BaseModel): + """Generation data from LLM invocation with tools. + + For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2 + - reasoning_contents: [thought1, thought2, ...] - one element per turn + - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results + """ + + text: str = Field(..., description="Accumulated text content from all turns") + reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn") + tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results") + sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering") + usage: LLMUsage = Field(..., description="LLM usage statistics") + finish_reason: str | None = Field(None, description="Finish reason from LLM") + files: list[File] = Field(default_factory=list, description="Generated files") + trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") + + +class ThinkTagStreamParser: + """Lightweight state machine to split streaming chunks by <think> tags.""" + + _START_PATTERN = re.compile(r"<think(\s[^>]*)?>", re.IGNORECASE) + _END_PATTERN = re.compile(r"</think>", re.IGNORECASE) + _START_PREFIX = "<think" + _END_PREFIX = "</think" + + def __init__(self): + self._buffer = "" + self._in_think = False + + def _suffix_prefix_len(self, text: str, prefix: str) -> int: + """Return length of the longest suffix of `text` that is a prefix of `prefix`.""" + max_len = min(len(text), len(prefix) - 1) + for i in range(max_len, 0, -1): + if text[-i:].lower() == prefix[:i].lower(): + return i + return 0 + + def process(self, chunk: str) -> list[tuple[str, str]]: + """ + Split incoming chunk into ('thought' | 'text', content) tuples. + Content excludes the tags themselves and handles split tags across chunks.
+ """ + parts: list[tuple[str, str]] = [] + self._buffer += chunk + + while self._buffer: + if self._in_think: + end_match = self._END_PATTERN.search(self._buffer) + if end_match: + thought_text = self._buffer[: end_match.start()] + if thought_text: + parts.append(("thought", thought_text)) + self._buffer = self._buffer[end_match.end() :] + self._in_think = False + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("thought", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + start_match = self._START_PATTERN.search(self._buffer) + if start_match: + prefix = self._buffer[: start_match.start()] + if prefix: + parts.append(("text", prefix)) + self._buffer = self._buffer[start_match.end() :] + self._in_think = True + continue + + hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX) + emit = self._buffer[: len(self._buffer) - hold_len] + if emit: + parts.append(("text", emit)) + self._buffer = self._buffer[-hold_len:] if hold_len > 0 else "" + break + + cleaned_parts: list[tuple[str, str]] = [] + for kind, content in parts: + # Extra safeguard: strip any stray tags that slipped through. + content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + if content: + cleaned_parts.append((kind, content)) + + return cleaned_parts + + def flush(self) -> list[tuple[str, str]]: + """Flush remaining buffer when the stream ends.""" + if not self._buffer: + return [] + kind = "thought" if self._in_think else "text" + content = self._buffer + # Drop dangling partial tags instead of emitting them + if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX): + content = "" + self._buffer = "" + if not content: + return [] + # Strip any complete tags that might still be present. 
+ content = self._START_PATTERN.sub("", content) + content = self._END_PATTERN.sub("", content) + return [(kind, content)] if content else [] + + +class StreamBuffers(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser) + pending_thought: list[str] = Field(default_factory=list) + pending_content: list[str] = Field(default_factory=list) + current_turn_reasoning: list[str] = Field(default_factory=list) + reasoning_per_turn: list[str] = Field(default_factory=list) + + +class TraceState(BaseModel): + trace_segments: list[LLMTraceSegment] = Field(default_factory=list) + tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict) + tool_call_index_map: dict[str, int] = Field(default_factory=dict) + + +class AggregatedResult(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + text: str = "" + files: list[File] = Field(default_factory=list) + usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) + finish_reason: str | None = None + + +class AgentContext(BaseModel): + agent_logs: list[AgentLogEvent] = Field(default_factory=list) + agent_result: AgentResult | None = None + + +class ToolOutputState(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + stream: StreamBuffers = Field(default_factory=StreamBuffers) + trace: TraceState = Field(default_factory=TraceState) + aggregate: AggregatedResult = Field(default_factory=AggregatedResult) + agent: AgentContext = Field(default_factory=AgentContext) + + +class ToolLogPayload(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + tool_name: str = "" + tool_call_id: str = "" + tool_args: dict[str, Any] = Field(default_factory=dict) + tool_output: Any = None + tool_error: Any = None + files: list[Any] = Field(default_factory=list) + meta: dict[str, Any] = Field(default_factory=dict) + + @classmethod + def from_log(cls, log: AgentLog) -> "ToolLogPayload": + data = log.data or {} + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + @classmethod + def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload": + return cls( + tool_name=data.get("tool_name", ""), + tool_call_id=data.get("tool_call_id", ""), + tool_args=data.get("tool_args") or {}, + tool_output=data.get("output"), + tool_error=data.get("error"), + files=data.get("files") or [], + meta=data.get("meta") or {}, + ) + + class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate @@ -86,6 +322,10 @@ class LLMNodeData(BaseNodeData): ), ) + # Tool support + tools: Sequence[ToolMetadata] = Field(default_factory=list) + max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node") + @field_validator("prompt_config", mode="before") @classmethod def convert_none_prompt_config(cls, v: Any): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 04e2802191..6be59b6ead 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import select +from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext 
+from core.agent.patterns import StrategyFactory from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage @@ -46,7 +48,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file +from core.tools.tool_manager import ToolManager from core.variables import ( ArrayFileSegment, ArraySegment, @@ -56,7 +60,8 @@ from core.variables import ( StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams +from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus +from core.workflow.entities.tool_entities import ToolCallResult from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -64,12 +69,16 @@ from core.workflow.enums import ( WorkflowNodeExecutionStatus, ) from core.workflow.node_events import ( + AgentLogEvent, ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, ) from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node @@ -81,10 +90,19 @@ from models.model import UploadFile from . import llm_utils from .entities import ( + AgentContext, + AggregatedResult, + LLMGenerationData, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, + LLMTraceSegment, ModelConfig, + StreamBuffers, + ThinkTagStreamParser, + ToolLogPayload, + ToolOutputState, + TraceState, ) from .exc import ( InvalidContextStructureError, @@ -149,11 +167,11 @@ class LLMNode(Node[LLMNodeData]): def _run(self) -> Generator: node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} - result_text = "" clean_text = "" usage = LLMUsage.empty_usage() finish_reason = None - reasoning_content = None + reasoning_content = "" # Initialize as empty string for consistency + clean_text = "" # Initialize clean_text to avoid UnboundLocalError variable_pool = self.graph_runtime_state.variable_pool try: @@ -234,55 +252,58 @@ class LLMNode(Node[LLMNodeData]): context_files=context_files, ) - # handle invoke result - generator = LLMNode.invoke_llm( - node_data_model=self.node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - reasoning_format=self.node_data.reasoning_format, - ) - + # Variables for outputs + generation_data: LLMGenerationData | None = None structured_output: LLMStructuredOutput | None = None - for event in generator: - if isinstance(event, StreamChunkEvent): - yield event - elif isinstance(event, ModelInvokeCompletedEvent): - # Raw text - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - reasoning_content = event.reasoning_content or "" + # Check if tools are 
configured + if self.tool_call_enabled: + # Use tool-enabled invocation (Agent V2 style) + generator = self._invoke_llm_with_tools( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + files=files, + variable_pool=variable_pool, + node_inputs=node_inputs, + process_data=process_data, + ) + else: + # Use traditional LLM invocation + generator = LLMNode.invoke_llm( + node_data_model=self._node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self._node_id, + node_type=self.node_type, + reasoning_format=self._node_data.reasoning_format, + ) - # For downstream nodes, determine clean text based on reasoning_format - if self.node_data.reasoning_format == "tagged": - # Keep tags for backward compatibility - clean_text = result_text - else: - # Extract clean text from tags - clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format) + ( + clean_text, + reasoning_content, + generation_reasoning_content, + generation_clean_content, + usage, + finish_reason, + structured_output, + generation_data, + ) = yield from self._stream_llm_events(generator, model_instance=model_instance) - # Process structured output if available from the event. - structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - # deduct quota - llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event + # Extract variables from generation_data if available + if generation_data: + clean_text = generation_data.text + reasoning_content = "" + usage = generation_data.usage + finish_reason = generation_data.finish_reason + # Unified process_data building process_data = { "model_mode": model_config.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( @@ -293,24 +314,88 @@ class LLMNode(Node[LLMNodeData]): "model_provider": model_config.provider, "model_name": model_config.model, } + if self.tool_call_enabled and self._node_data.tools: + process_data["tools"] = [ + { + "type": tool.type.value if hasattr(tool.type, "value") else tool.type, + "provider_name": tool.provider_name, + "tool_name": tool.tool_name, + } + for tool in self._node_data.tools + if tool.enabled + ] + # Unified outputs building outputs = { "text": clean_text, "reasoning_content": reasoning_content, "usage": jsonable_encoder(usage), "finish_reason": finish_reason, } + + # Build generation field + if generation_data: + # Use generation_data from tool invocation (supports multi-turn) + generation = { + "content": generation_data.text, + "reasoning_content": generation_data.reasoning_contents, # [thought1, thought2, ...] 
+ "tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls], + "sequence": generation_data.sequence, + } + files_to_output = generation_data.files + else: + # Traditional LLM invocation + generation_reasoning = generation_reasoning_content or reasoning_content + generation_content = generation_clean_content or clean_text + sequence: list[dict[str, Any]] = [] + if generation_reasoning: + sequence = [ + {"type": "reasoning", "index": 0}, + {"type": "content", "start": 0, "end": len(generation_content)}, + ] + generation = { + "content": generation_content, + "reasoning_content": [generation_reasoning] if generation_reasoning else [], + "tool_calls": [], + "sequence": sequence, + } + files_to_output = self._file_outputs + + outputs["generation"] = generation + if files_to_output: + outputs["files"] = ArrayFileSegment(value=files_to_output) if structured_output: outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) + if not self.tool_call_enabled: + # For tool calls, final events are already sent in _process_tool_outputs + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=True, + ) + + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, + } + + if generation_data and generation_data.trace: + metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [ + segment.model_dump() for segment in generation_data.trace + ] yield StreamCompletedEvent( node_run_result=NodeRunResult( @@ -318,11 +403,7 @@ class LLMNode(Node[LLMNodeData]): inputs=node_inputs, process_data=process_data, outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, + metadata=metadata, llm_usage=usage, ) ) @@ -444,6 +525,8 @@ class LLMNode(Node[LLMNodeData]): usage = LLMUsage.empty_usage() finish_reason = None full_text_buffer = io.StringIO() + think_parser = ThinkTagStreamParser() + reasoning_chunks: list[str] = [] # Initialize streaming metrics tracking start_time = request_start_time if request_start_time is not None else time.perf_counter() @@ -472,12 +555,32 @@ class LLMNode(Node[LLMNodeData]): has_content = True full_text_buffer.write(text_part) + # Text output: always forward raw chunk (keep tags intact) yield StreamChunkEvent( selector=[node_id, "text"], chunk=text_part, is_final=False, ) + # Generation output: split out thoughts, forward only non-thought content chunks + for kind, segment in think_parser.process(text_part): + if not segment: + continue + + if kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + else: + yield StreamChunkEvent( + selector=[node_id, "generation", "content"], + chunk=segment, + 
is_final=False, + ) + # Update the whole metadata if not model and result.model: model = result.model @@ -492,16 +595,35 @@ class LLMNode(Node[LLMNodeData]): except OutputParserError as e: raise LLMNodeError(f"Failed to parse structured output: {e}") + for kind, segment in think_parser.flush(): + if not segment: + continue + if kind == "thought": + reasoning_chunks.append(segment) + yield ThoughtChunkEvent( + selector=[node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + else: + yield StreamChunkEvent( + selector=[node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + # Extract reasoning content from tags in the main text full_text = full_text_buffer.getvalue() if reasoning_format == "tagged": # Keep tags in text for backward compatibility clean_text = full_text - reasoning_content = "" + reasoning_content = "".join(reasoning_chunks) else: # Extract clean text and reasoning from tags clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) + if reasoning_chunks and not reasoning_content: + reasoning_content = "".join(reasoning_chunks) # Calculate streaming metrics end_time = time.perf_counter() @@ -1266,6 +1388,635 @@ class LLMNode(Node[LLMNodeData]): def retry(self) -> bool: return self.node_data.retry_config.retry_enabled + @property + def tool_call_enabled(self) -> bool: + return ( + self.node_data.tools is not None + and len(self.node_data.tools) > 0 + and all(tool.enabled for tool in self.node_data.tools) + ) + + def _stream_llm_events( + self, + generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None], + *, + model_instance: ModelInstance, + ) -> Generator[ + NodeEventBase, + None, + tuple[ + str, + str, + str, + str, + LLMUsage, + str | None, + LLMStructuredOutput | None, + LLMGenerationData | None, + ], + ]: + """ + Stream events and capture generator return value in one place. + Uses generator delegation so _run stays concise while still emitting events. 
+ """ + clean_text = "" + reasoning_content = "" + generation_reasoning_content = "" + generation_clean_content = "" + usage = LLMUsage.empty_usage() + finish_reason: str | None = None + structured_output: LLMStructuredOutput | None = None + generation_data: LLMGenerationData | None = None + completed = False + + while True: + try: + event = next(generator) + except StopIteration as exc: + if isinstance(exc.value, LLMGenerationData): + generation_data = exc.value + break + + if completed: + # After completion we still drain to reach StopIteration.value + continue + + match event: + case StreamChunkEvent() | ThoughtChunkEvent(): + yield event + + case ModelInvokeCompletedEvent( + text=text, + usage=usage_event, + finish_reason=finish_reason_event, + reasoning_content=reasoning_event, + structured_output=structured_raw, + ): + clean_text = text + usage = usage_event + finish_reason = finish_reason_event + reasoning_content = reasoning_event or "" + generation_reasoning_content = reasoning_content + generation_clean_content = clean_text + + if self.node_data.reasoning_format == "tagged": + # Keep tagged text for output; also extract reasoning for generation field + generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning( + clean_text, reasoning_format="separated" + ) + else: + clean_text, generation_reasoning_content = LLMNode._split_reasoning( + clean_text, self.node_data.reasoning_format + ) + generation_clean_content = clean_text + + structured_output = ( + LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None + ) + + llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + completed = True + + case LLMStructuredOutput(): + structured_output = event + + case _: + continue + + return ( + clean_text, + reasoning_content, + generation_reasoning_content, + generation_clean_content, + usage, + finish_reason, + structured_output, + generation_data, + ) + + def _invoke_llm_with_tools( + self, + model_instance: ModelInstance, + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + files: Sequence["File"], + variable_pool: VariablePool, + node_inputs: dict[str, Any], + process_data: dict[str, Any], + ) -> Generator[NodeEventBase, None, LLMGenerationData]: + """Invoke LLM with tools support (from Agent V2). 
+ + Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files + """ + # Get model features to determine strategy + model_features = self._get_model_features(model_instance) + + # Prepare tool instances + tool_instances = self._prepare_tool_instances(variable_pool) + + # Prepare prompt files (files that come from prompt variables, not vision files) + prompt_files = self._extract_prompt_files(variable_pool) + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=model_instance, + tools=tool_instances, + files=prompt_files, + max_iterations=self._node_data.max_iterations or 10, + context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), + ) + + # Run strategy + outputs = strategy.run( + prompt_messages=list(prompt_messages), + model_parameters=self._node_data.model.completion_params, + stop=list(stop or []), + stream=True, + ) + + # Process outputs and return generation result + result = yield from self._process_tool_outputs(outputs) + return result + + def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: + """Get model schema to determine features.""" + try: + model_type_instance = model_instance.model_type_instance + model_schema = model_type_instance.get_model_schema( + model_instance.model, + model_instance.credentials, + ) + return model_schema.features if model_schema and model_schema.features else [] + except Exception: + logger.warning("Failed to get model schema, assuming no special features") + return [] + + def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]: + """Prepare tool instances from configuration.""" + tool_instances = [] + + if self._node_data.tools: + for tool in self._node_data.tools: + try: + # Process settings to extract the correct structure + processed_settings = {} + for key, value in tool.settings.items(): + if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict): + # Extract the nested value if it has the ToolInput structure + if "type" in value["value"] and "value" in value["value"]: + processed_settings[key] = value["value"] + else: + processed_settings[key] = value + else: + processed_settings[key] = value + + # Merge parameters with processed settings (similar to Agent Node logic) + merged_parameters = {**tool.parameters, **processed_settings} + + # Create AgentToolEntity from ToolMetadata + agent_tool = AgentToolEntity( + provider_id=tool.provider_name, + provider_type=tool.type, + tool_name=tool.tool_name, + tool_parameters=merged_parameters, + plugin_unique_identifier=tool.plugin_unique_identifier, + credential_id=tool.credential_id, + ) + + # Get tool runtime from ToolManager + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + app_id=self.app_id, + agent_tool=agent_tool, + invoke_from=self.invoke_from, + variable_pool=variable_pool, + ) + + # Apply custom description from extra field if available + if tool.extra.get("description") and tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + tool.extra.get("description") or tool_runtime.entity.description.llm + ) + + tool_instances.append(tool_runtime) + except Exception as e: + logger.warning("Failed to load tool %s: %s", tool, str(e)) + continue + + return tool_instances + + def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]: + """Extract files from prompt template variables.""" + 
from core.variables import ArrayFileVariable, FileVariable + + files: list[File] = [] + + # Extract variables from prompt template + if isinstance(self._node_data.prompt_template, list): + for message in self._node_data.prompt_template: + if message.text: + parser = VariableTemplateParser(message.text) + variable_selectors = parser.extract_variable_selectors() + + for variable_selector in variable_selectors: + variable = variable_pool.get(variable_selector.value_selector) + if isinstance(variable, FileVariable) and variable.value: + files.append(variable.value) + elif isinstance(variable, ArrayFileVariable) and variable.value: + files.extend(variable.value) + + return files + + @staticmethod + def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]: + """Convert ToolCallResult into JSON-friendly dict.""" + + def _file_to_ref(file: File) -> str | None: + # Align with streamed tool result events which carry file IDs + return file.id or file.related_id + + files = [] + for file in tool_call.files or []: + ref = _file_to_ref(file) + if ref: + files.append(ref) + + return { + "id": tool_call.id, + "name": tool_call.name, + "arguments": tool_call.arguments, + "output": tool_call.output, + "files": files, + "status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status, + } + + def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: + if not buffers.pending_thought: + return + trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought))) + buffers.pending_thought.clear() + + def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None: + if not buffers.pending_content: + return + trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content))) + buffers.pending_content.clear() + + def _handle_agent_log_output( + self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext + ) -> Generator[NodeEventBase, None, None]: + payload = ToolLogPayload.from_log(output) + agent_log_event = AgentLogEvent( + message_id=output.id, + label=output.label, + node_execution_id=self.id, + parent_id=output.parent_id, + error=output.error, + status=output.status.value, + data=output.data, + metadata={k.value: v for k, v in output.metadata.items()}, + node_id=self._node_id, + ) + for log in agent_context.agent_logs: + if log.message_id == agent_log_event.message_id: + log.data = agent_log_event.data + log.status = agent_log_event.status + log.error = agent_log_event.error + log.label = agent_log_event.label + log.metadata = agent_log_event.metadata + break + else: + agent_context.agent_logs.append(agent_log_event) + + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_call_id = payload.tool_call_id + tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else "" + + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) + + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) + + tool_call_segment = LLMTraceSegment( + type="tool_call", + text=None, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), + ) + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + 
trace_state.tool_trace_map[tool_call_id] = tool_call_segment + + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk=tool_arguments, + tool_call=ToolCall( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), + is_final=False, + ) + + if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: + tool_name = payload.tool_name + tool_output = payload.tool_output + tool_call_id = payload.tool_call_id + tool_files = payload.files if isinstance(payload.files, list) else [] + tool_error = payload.tool_error + + if tool_call_id and tool_call_id not in trace_state.tool_call_index_map: + trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map) + + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) + + if output.status == AgentLog.LogStatus.ERROR: + tool_error = output.error or payload.tool_error + if not tool_error and payload.meta: + tool_error = payload.meta.get("error") + else: + if payload.meta: + meta_error = payload.meta.get("error") + if meta_error: + tool_error = meta_error + + existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id) + tool_call_segment = existing_tool_segment or LLMTraceSegment( + type="tool_call", + text=None, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ), + ) + if existing_tool_segment is None: + trace_state.trace_segments.append(tool_call_segment) + if tool_call_id: + trace_state.tool_trace_map[tool_call_id] = tool_call_segment + + if tool_call_segment.tool_call is None: + tool_call_segment.tool_call = ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ) + tool_call_segment.tool_call.output = ( + str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None + ) + tool_call_segment.tool_call.files = [] + tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS + + result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None + + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk=result_output or "", + tool_result=ToolResult( + id=tool_call_id, + name=tool_name, + output=result_output, + files=tool_files, + status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, + ), + is_final=False, + ) + + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) + buffers.current_turn_reasoning.clear() + + def _handle_llm_chunk_output( + self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + message = output.delta.message + + if message and message.content: + chunk_text = message.content + if isinstance(chunk_text, list): + chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text) + else: + chunk_text = str(chunk_text) + + for kind, segment in buffers.think_parser.process(chunk_text): + if not segment: + continue + + if kind == "thought": + self._flush_content_segment(buffers, trace_state) + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + else: + self._flush_thought_segment(buffers, trace_state) + aggregate.text += 
segment + buffers.pending_content.append(segment) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + if output.delta.usage: + self._accumulate_usage(aggregate.usage, output.delta.usage) + + if output.delta.finish_reason: + aggregate.finish_reason = output.delta.finish_reason + + def _flush_remaining_stream( + self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult + ) -> Generator[NodeEventBase, None, None]: + for kind, segment in buffers.think_parser.flush(): + if not segment: + continue + if kind == "thought": + self._flush_content_segment(buffers, trace_state) + buffers.current_turn_reasoning.append(segment) + buffers.pending_thought.append(segment) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk=segment, + is_final=False, + ) + else: + self._flush_thought_segment(buffers, trace_state) + aggregate.text += segment + buffers.pending_content.append(segment) + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=segment, + is_final=False, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk=segment, + is_final=False, + ) + + if buffers.current_turn_reasoning: + buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning)) + + self._flush_thought_segment(buffers, trace_state) + self._flush_content_segment(buffers, trace_state) + + def _close_streams(self) -> Generator[NodeEventBase, None, None]: + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + yield StreamChunkEvent( + selector=[self._node_id, "generation", "content"], + chunk="", + is_final=True, + ) + yield ThoughtChunkEvent( + selector=[self._node_id, "generation", "thought"], + chunk="", + is_final=True, + ) + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk="", + tool_call=ToolCall( + id="", + name="", + arguments="", + ), + is_final=True, + ) + yield ToolResultChunkEvent( + selector=[self._node_id, "generation", "tool_results"], + chunk="", + tool_result=ToolResult( + id="", + name="", + output="", + files=[], + status=ToolResultStatus.SUCCESS, + ), + is_final=True, + ) + + def _build_generation_data( + self, + trace_state: TraceState, + agent_context: AgentContext, + aggregate: AggregatedResult, + buffers: StreamBuffers, + ) -> LLMGenerationData: + sequence: list[dict[str, Any]] = [] + reasoning_index = 0 + content_position = 0 + tool_call_seen_index: dict[str, int] = {} + for trace_segment in trace_state.trace_segments: + if trace_segment.type == "thought": + sequence.append({"type": "reasoning", "index": reasoning_index}) + reasoning_index += 1 + elif trace_segment.type == "content": + segment_text = trace_segment.text or "" + start = content_position + end = start + len(segment_text) + sequence.append({"type": "content", "start": start, "end": end}) + content_position = end + elif trace_segment.type == "tool_call": + tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else "" + if tool_id not in tool_call_seen_index: + tool_call_seen_index[tool_id] = len(tool_call_seen_index) + sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) + + tool_calls_for_generation: list[ToolCallResult] = [] + for log in agent_context.agent_logs: + payload = ToolLogPayload.from_mapping(log.data or 
{})
+            tool_call_id = payload.tool_call_id
+            if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
+                continue
+
+            tool_args = payload.tool_args
+            log_error = payload.tool_error
+            log_output = payload.tool_output
+            result_text = log_output or log_error or ""
+            status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
+            tool_calls_for_generation.append(
+                ToolCallResult(
+                    id=tool_call_id,
+                    name=payload.tool_name,
+                    arguments=json.dumps(tool_args) if tool_args else "",
+                    output=result_text,
+                    status=status,
+                )
+            )
+
+        tool_calls_for_generation.sort(
+            key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
+        )
+
+        return LLMGenerationData(
+            text=aggregate.text,
+            reasoning_contents=buffers.reasoning_per_turn,
+            tool_calls=tool_calls_for_generation,
+            sequence=sequence,
+            usage=aggregate.usage,
+            finish_reason=aggregate.finish_reason,
+            files=aggregate.files,
+            trace=trace_state.trace_segments,
+        )
+
+    def _process_tool_outputs(
+        self,
+        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
+    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
+        """Process strategy outputs and convert to node events."""
+        state = ToolOutputState()
+
+        # Drive the generator manually: a plain `for` loop would consume the
+        # StopIteration that carries the AgentResult return value, so an
+        # except clause around the loop could never capture it.
+        while True:
+            try:
+                output = next(outputs)
+            except StopIteration as exception:
+                if isinstance(exception.value, AgentResult):
+                    state.agent.agent_result = exception.value
+                break
+
+            if isinstance(output, AgentLog):
+                yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
+            else:
+                yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
+
+        if state.agent.agent_result:
+            state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
+            state.aggregate.files = state.agent.agent_result.files
+            if state.agent.agent_result.usage:
+                state.aggregate.usage = state.agent.agent_result.usage
+            if state.agent.agent_result.finish_reason:
+                state.aggregate.finish_reason = state.agent.agent_result.finish_reason
+
+        yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
+        yield from self._close_streams()
+
+        return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
+
+    def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
+        """Accumulate LLM usage statistics."""
+        total_usage.prompt_tokens += delta_usage.prompt_tokens
+        total_usage.completion_tokens += delta_usage.completion_tokens
+        total_usage.total_tokens += delta_usage.total_tokens
+        total_usage.prompt_price += delta_usage.prompt_price
+        total_usage.completion_price += delta_usage.completion_price
+        total_usage.total_price += delta_usage.total_price
+
 def _combine_message_content_with_role(
     *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index ecc267cf38..d5b2574edc 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -89,6 +89,7 @@ message_detail_fields = {
     "status": fields.String,
     "error": fields.String,
     "parent_message_id": fields.String,
+    "generation_detail": fields.Raw,
 }
 
 feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index a419da2e18..8b9bcac76f 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -68,6 +68,7 @@ message_fields = {
     "message_files": fields.List(fields.Nested(message_file_fields)),
     "status": fields.String,
     "error": fields.String,
+    "generation_detail": fields.Raw,
 }
 
 message_infinite_scroll_pagination_fields = {
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 821ce62ecc..7b878e05c8 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -81,6 +81,7 @@ workflow_run_detail_fields = {
     "inputs": fields.Raw(attribute="inputs_dict"),
     "status": fields.String,
     "outputs": fields.Raw(attribute="outputs_dict"),
+    "outputs_as_generation": fields.Boolean,
     "error": fields.String,
     "elapsed_time": fields.Float,
     "total_tokens": fields.Integer,
@@ -129,6 +130,7 @@ workflow_run_node_execution_fields = {
     "inputs_truncated": fields.Boolean,
     "outputs_truncated": fields.Boolean,
     "process_data_truncated": fields.Boolean,
+    "generation_detail": fields.Raw,
 }
 
 workflow_run_node_execution_list_fields = {
diff --git a/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py b/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py
new file mode 100644
index 0000000000..60786a720c
--- /dev/null
+++ b/api/migrations/versions/2025_12_17_1617-85c8b4a64f53_add_llm_generation_detail_table.py
@@ -0,0 +1,46 @@
+"""add llm generation detail table.
+
+Revision ID: 85c8b4a64f53
+Revises: 7bb281b7a422
+Create Date: 2025-12-10 16:17:46.597669
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '85c8b4a64f53'
+down_revision = '03ea244985ce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('llm_generation_details',
+    sa.Column('id', models.types.StringUUID(), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('app_id', models.types.StringUUID(), nullable=False),
+    sa.Column('message_id', models.types.StringUUID(), nullable=True),
+    sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
+    sa.Column('node_id', sa.String(length=255), nullable=True),
+    sa.Column('reasoning_content', models.types.LongText(), nullable=True),
+    sa.Column('tool_calls', models.types.LongText(), nullable=True),
+    sa.Column('sequence', models.types.LongText(), nullable=True),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.CheckConstraint('(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)', name=op.f('llm_generation_details_ck_llm_generation_detail_assoc_mode_check')),
+    sa.PrimaryKeyConstraint('id', name='llm_generation_detail_pkey'),
+    sa.UniqueConstraint('message_id', name=op.f('llm_generation_details_message_id_key'))
+    )
+    with op.batch_alter_table('llm_generation_details', schema=None) as batch_op:
+        batch_op.create_index('idx_llm_generation_detail_message', ['message_id'], unique=False)
+        batch_op.create_index('idx_llm_generation_detail_workflow', ['workflow_run_id', 'node_id'], unique=False)
+
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + op.drop_table('llm_generation_details') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index e23de832dc..7b81cea415 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -49,6 +49,7 @@ from .model import ( EndUser, IconType, InstalledApp, + LLMGenerationDetail, Message, MessageAgentThought, MessageAnnotation, @@ -155,6 +156,7 @@ __all__ = [ "IconType", "InstalledApp", "InvitationCode", + "LLMGenerationDetail", "LoadBalancingModelConfig", "Message", "MessageAgentThought", diff --git a/api/models/model.py b/api/models/model.py index 44bcabe96f..1bfcb98542 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -31,6 +31,8 @@ from .provider_ids import GenericProviderID from .types import LongText, StringUUID if TYPE_CHECKING: + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + from .workflow import Workflow @@ -1201,6 +1203,18 @@ class Message(Base): .all() ) + # FIXME (Novice) -- It's easy to cause N+1 query problem here. + @property + def generation_detail(self) -> dict[str, Any] | None: + """ + Get LLM generation detail for this message. + Returns the detail as a dictionary or None if not found. + """ + detail = db.session.query(LLMGenerationDetail).filter_by(message_id=self.id).first() + if detail: + return detail.to_dict() + return None + @property def retriever_resources(self) -> Any: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @@ -2091,3 +2105,87 @@ class TenantCreditPool(Base): def has_sufficient_credits(self, required_credits: int) -> bool: return self.remaining_credits >= required_credits + + +class LLMGenerationDetail(Base): + """ + Store LLM generation details including reasoning process and tool calls. 
+ + Association (choose one): + - For apps with Message: use message_id (one-to-one) + - For Workflow: use workflow_run_id + node_id (one run may have multiple LLM nodes) + """ + + __tablename__ = "llm_generation_details" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"), + sa.Index("idx_llm_generation_detail_message", "message_id"), + sa.Index("idx_llm_generation_detail_workflow", "workflow_run_id", "node_id"), + sa.CheckConstraint( + "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL)" + " OR " + "(message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)", + name="ck_llm_generation_detail_assoc_mode", + ), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Association fields (choose one) + message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, unique=True) + workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + node_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # Core data as JSON strings + reasoning_content: Mapped[str | None] = mapped_column(LongText) + tool_calls: Mapped[str | None] = mapped_column(LongText) + sequence: Mapped[str | None] = mapped_column(LongText) + + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + + def to_domain_model(self) -> "LLMGenerationDetailData": + """Convert to Pydantic domain model with proper validation.""" + from core.app.entities.llm_generation_entities import LLMGenerationDetailData + + return LLMGenerationDetailData( + reasoning_content=json.loads(self.reasoning_content) if self.reasoning_content else [], + tool_calls=json.loads(self.tool_calls) if self.tool_calls else [], + sequence=json.loads(self.sequence) if self.sequence else [], + ) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for API response.""" + return self.to_domain_model().to_response_dict() + + @classmethod + def from_domain_model( + cls, + data: "LLMGenerationDetailData", + *, + tenant_id: str, + app_id: str, + message_id: str | None = None, + workflow_run_id: str | None = None, + node_id: str | None = None, + ) -> "LLMGenerationDetail": + """Create from Pydantic domain model.""" + # Enforce association mode at object creation time as well. 
+ message_mode = message_id is not None + workflow_mode = workflow_run_id is not None or node_id is not None + if message_mode and workflow_mode: + raise ValueError("LLMGenerationDetail cannot set both message_id and workflow_run_id/node_id.") + if not message_mode and not (workflow_run_id and node_id): + raise ValueError("LLMGenerationDetail requires either message_id or workflow_run_id+node_id.") + + return cls( + tenant_id=tenant_id, + app_id=app_id, + message_id=message_id, + workflow_run_id=workflow_run_id, + node_id=node_id, + reasoning_content=json.dumps(data.reasoning_content) if data.reasoning_content else None, + tool_calls=json.dumps([tc.model_dump() for tc in data.tool_calls]) if data.tool_calls else None, + sequence=json.dumps([seg.model_dump() for seg in data.sequence]) if data.sequence else None, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 853d5afefc..5131177836 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -57,6 +57,37 @@ from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +def is_generation_outputs(outputs: Mapping[str, Any]) -> bool: + if not outputs: + return False + + allowed_sequence_types = {"reasoning", "content", "tool_call"} + + def valid_sequence_item(item: Mapping[str, Any]) -> bool: + return isinstance(item, Mapping) and item.get("type") in allowed_sequence_types + + def valid_value(value: Any) -> bool: + if not isinstance(value, Mapping): + return False + + content = value.get("content") + reasoning_content = value.get("reasoning_content") + tool_calls = value.get("tool_calls") + sequence = value.get("sequence") + + return ( + isinstance(content, str) + and isinstance(reasoning_content, list) + and all(isinstance(item, str) for item in reasoning_content) + and isinstance(tool_calls, list) + and all(isinstance(item, Mapping) for item in tool_calls) + and isinstance(sequence, list) + and all(valid_sequence_item(item) for item in sequence) + ) + + return all(valid_value(value) for value in outputs.values()) + + class WorkflowType(StrEnum): """ Workflow Type Enum @@ -664,6 +695,10 @@ class WorkflowRun(Base): def workflow(self): return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + @property + def outputs_as_generation(self): + return is_generation_outputs(self.outputs_dict) + def to_dict(self): return { "id": self.id, @@ -677,6 +712,7 @@ class WorkflowRun(Base): "inputs": self.inputs_dict, "status": self.status, "outputs": self.outputs_dict, + "outputs_as_generation": self.outputs_as_generation, "error": self.error, "elapsed_time": self.elapsed_time, "total_tokens": self.total_tokens, diff --git a/api/services/llm_generation_service.py b/api/services/llm_generation_service.py new file mode 100644 index 0000000000..eb8327537e --- /dev/null +++ b/api/services/llm_generation_service.py @@ -0,0 +1,37 @@ +""" +LLM Generation Detail Service. + +Provides methods to query and attach generation details to workflow node executions +and messages, avoiding N+1 query problems. 
+""" + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.llm_generation_entities import LLMGenerationDetailData +from models import LLMGenerationDetail + + +class LLMGenerationService: + """Service for handling LLM generation details.""" + + def __init__(self, session: Session): + self._session = session + + def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None: + """Query generation detail for a specific message.""" + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id) + detail = self._session.scalars(stmt).first() + return detail.to_domain_model() if detail else None + + def get_generation_details_for_messages( + self, + message_ids: list[str], + ) -> dict[str, LLMGenerationDetailData]: + """Batch query generation details for multiple messages.""" + if not message_ids: + return {} + + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids)) + details = self._session.scalars(stmt).all() + return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id} diff --git a/api/tests/unit_tests/core/agent/__init__.py b/api/tests/unit_tests/core/agent/__init__.py new file mode 100644 index 0000000000..e7c478bf83 --- /dev/null +++ b/api/tests/unit_tests/core/agent/__init__.py @@ -0,0 +1,4 @@ +""" +Mark agent test modules as a package to avoid import name collisions. +""" + diff --git a/api/tests/unit_tests/core/agent/patterns/__init__.py b/api/tests/unit_tests/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py new file mode 100644 index 0000000000..b0e0d44940 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -0,0 +1,324 @@ +"""Tests for AgentPattern base class.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.agent.patterns.base import AgentPattern +from core.model_runtime.entities.llm_entities import LLMUsage + + +class ConcreteAgentPattern(AgentPattern): + """Concrete implementation of AgentPattern for testing.""" + + def run(self, prompt_messages, model_parameters, stop=[], stream=True): + """Minimal implementation for testing.""" + yield from [] + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def agent_pattern(mock_model_instance, mock_context): + """Create a concrete agent pattern for testing.""" + return ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=10, + ) + + +class TestAccumulateUsage: + """Tests for _accumulate_usage method.""" + + def test_accumulate_usage_to_empty_dict(self, agent_pattern): + """Test accumulating usage to an empty dict creates a copy.""" + total_usage: dict = {"usage": None} + delta_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), 
+ prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"] is not None + assert total_usage["usage"].total_tokens == 150 + assert total_usage["usage"].prompt_tokens == 100 + assert total_usage["usage"].completion_tokens == 50 + # Verify it's a copy, not a reference + assert total_usage["usage"] is not delta_usage + + def test_accumulate_usage_adds_to_existing(self, agent_pattern): + """Test accumulating usage adds to existing values.""" + initial_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + total_usage: dict = {"usage": initial_usage} + + delta_usage = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + + agent_pattern._accumulate_usage(total_usage, delta_usage) + + assert total_usage["usage"].total_tokens == 450 # 150 + 300 + assert total_usage["usage"].prompt_tokens == 300 # 100 + 200 + assert total_usage["usage"].completion_tokens == 150 # 50 + 100 + + def test_accumulate_usage_multiple_rounds(self, agent_pattern): + """Test accumulating usage across multiple rounds.""" + total_usage: dict = {"usage": None} + + # Round 1: 100 tokens + round1_usage = LLMUsage( + prompt_tokens=70, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.07"), + completion_tokens=30, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.06"), + total_tokens=100, + total_price=Decimal("0.13"), + currency="USD", + latency=0.3, + ) + agent_pattern._accumulate_usage(total_usage, round1_usage) + assert total_usage["usage"].total_tokens == 100 + + # Round 2: 150 tokens + round2_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.4, + ) + agent_pattern._accumulate_usage(total_usage, round2_usage) + assert total_usage["usage"].total_tokens == 250 # 100 + 150 + + # Round 3: 200 tokens + round3_usage = LLMUsage( + prompt_tokens=130, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.13"), + completion_tokens=70, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.14"), + total_tokens=200, + total_price=Decimal("0.27"), + currency="USD", + latency=0.5, + ) + agent_pattern._accumulate_usage(total_usage, round3_usage) + assert total_usage["usage"].total_tokens == 
450 # 100 + 150 + 200 + + +class TestCreateLog: + """Tests for _create_log method.""" + + def test_create_log_with_label_and_status(self, agent_pattern): + """Test creating a log with label and status.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.parent_id is None + + def test_create_log_with_parent_id(self, agent_pattern): + """Test creating a log with parent_id.""" + parent_log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + child_log = agent_pattern._create_log( + label="CALL tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={}, + parent_id=parent_log.id, + ) + + assert child_log.parent_id == parent_log.id + assert child_log.log_type == AgentLog.LogType.TOOL_CALL + + +class TestFinishLog: + """Tests for _finish_log method.""" + + def test_finish_log_updates_status(self, agent_pattern): + """Test that finish_log updates status to SUCCESS.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + finished_log = agent_pattern._finish_log(log, data={"result": "done"}) + + assert finished_log.status == AgentLog.LogStatus.SUCCESS + assert finished_log.data == {"result": "done"} + + def test_finish_log_adds_usage_metadata(self, agent_pattern): + """Test that finish_log adds usage to metadata.""" + log = agent_pattern._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + finished_log = agent_pattern._finish_log(log, usage=usage) + + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150 + assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2") + assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD" + assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage + + +class TestFindToolByName: + """Tests for _find_tool_by_name method.""" + + def test_find_existing_tool(self, mock_model_instance, mock_context): + """Test finding an existing tool by name.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("test_tool") + assert found_tool == mock_tool + + def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context): + """Test that finding a nonexistent tool returns None.""" + mock_tool = MagicMock() + mock_tool.entity.identity.name = "test_tool" + + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + found_tool = pattern._find_tool_by_name("nonexistent_tool") + assert found_tool is None + + +class 
TestMaxIterationsCapping: + """Tests for max_iterations capping.""" + + def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context): + """Test that max_iterations is capped at 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=150, + ) + + assert pattern.max_iterations == 99 + + def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context): + """Test that max_iterations is not capped when under 99.""" + pattern = ConcreteAgentPattern( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + max_iterations=50, + ) + + assert pattern.max_iterations == 50 diff --git a/api/tests/unit_tests/core/agent/patterns/test_function_call.py b/api/tests/unit_tests/core/agent/patterns/test_function_call.py new file mode 100644 index 0000000000..6b3600dbbf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_function_call.py @@ -0,0 +1,332 @@ +"""Tests for FunctionCallStrategy.""" + +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentLog, ExecutionContext +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.message_entities import ( + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.to_prompt_message_tool.return_value = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={ + "type": "object", + "properties": {"param1": {"type": "string", "description": "A parameter"}}, + "required": ["param1"], + }, + ) + return tool + + +class TestFunctionCallStrategyInit: + """Tests for FunctionCallStrategy initialization.""" + + def test_initialization(self, mock_model_instance, mock_context, mock_tool): + """Test basic initialization.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=10, + ) + + assert strategy.model_instance == mock_model_instance + assert strategy.context == mock_context + assert strategy.max_iterations == 10 + assert len(strategy.tools) == 1 + + def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool): + """Test initialization with tool_invoke_hook.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + mock_hook = MagicMock() + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook + + +class TestConvertToolsToPromptFormat: + """Tests for _convert_tools_to_prompt_format method.""" + + def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool): + """Test that 
_convert_tools_to_prompt_format returns PromptMessageTool list.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert len(tools) == 1 + assert isinstance(tools[0], PromptMessageTool) + assert tools[0].name == "test_tool" + + def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context): + """Test that _convert_tools_to_prompt_format returns empty list when no tools.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + tools = strategy._convert_tools_to_prompt_format() + + assert tools == [] + + +class TestAgentLogGeneration: + """Tests for AgentLog generation during run.""" + + def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that round logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + max_iterations=1, + ) + + # Create a round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"inputs": {"query": "test"}}, + ) + + assert round_log.label == "ROUND 1" + assert round_log.log_type == AgentLog.LogType.ROUND + assert round_log.status == AgentLog.LogStatus.START + assert "inputs" in round_log.data + + def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool): + """Test that tool call logs have correct structure.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + # Create a parent round log + round_log = strategy._create_log( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + + # Create a tool call log + tool_log = strategy._create_log( + label="CALL test_tool", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}}, + parent_id=round_log.id, + ) + + assert tool_log.label == "CALL test_tool" + assert tool_log.log_type == AgentLog.LogType.TOOL_CALL + assert tool_log.parent_id == round_log.id + assert tool_log.data["tool_name"] == "test_tool" + + +class TestToolInvocation: + """Tests for tool invocation.""" + + def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool): + """Test that tool invocation uses hook when provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_hook = MagicMock() + mock_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={"tool_provider_type": "test", "tool_provider": "test_id"}, + ) + mock_hook.return_value = ("Tool result", ["file-1"], mock_meta) + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=mock_hook, + ) + + result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool") + + mock_hook.assert_called_once() + assert result == "Tool result" + assert 
files == [] # Hook returns file IDs, but _invoke_tool returns empty File list + assert meta == mock_meta + + def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool): + """Test that tool_invoke_hook is None when not provided.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + tool_invoke_hook=None, + ) + + # Verify that tool_invoke_hook is None + assert strategy.tool_invoke_hook is None + + +class TestUsageTracking: + """Tests for usage tracking across rounds.""" + + def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context): + """Test that round usage is tracked separately from total.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + # Simulate two rounds of usage + total_usage: dict = {"usage": None} + round1_usage: dict = {"usage": None} + round2_usage: dict = {"usage": None} + + # Round 1 + usage1 = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round1_usage, usage1) + strategy._accumulate_usage(total_usage, usage1) + + # Round 2 + usage2 = LLMUsage( + prompt_tokens=200, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.2"), + completion_tokens=100, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.2"), + total_tokens=300, + total_price=Decimal("0.4"), + currency="USD", + latency=0.5, + ) + strategy._accumulate_usage(round2_usage, usage2) + strategy._accumulate_usage(total_usage, usage2) + + # Verify round usage is separate + assert round1_usage["usage"].total_tokens == 150 + assert round2_usage["usage"].total_tokens == 300 + # Verify total is accumulated + assert total_usage["usage"].total_tokens == 450 + + +class TestPromptMessageHandling: + """Tests for prompt message handling.""" + + def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool): + """Test that messages include system and user prompts.""" + from core.agent.patterns.function_call import FunctionCallStrategy + + strategy = FunctionCallStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="You are a helpful assistant."), + UserPromptMessage(content="Hello"), + ] + + # Just verify the messages can be processed + assert len(messages) == 2 + assert isinstance(messages[0], SystemPromptMessage) + assert isinstance(messages[1], UserPromptMessage) + + def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool): + """Test that assistant messages can contain tool calls.""" + from core.model_runtime.entities.message_entities import AssistantPromptMessage + + tool_call = AssistantPromptMessage.ToolCall( + id="call_123", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name="test_tool", + arguments='{"param1": "value1"}', + ), + ) + + 
assistant_message = AssistantPromptMessage( + content="I'll help you with that.", + tool_calls=[tool_call], + ) + + assert len(assistant_message.tool_calls) == 1 + assert assistant_message.tool_calls[0].function.name == "test_tool" diff --git a/api/tests/unit_tests/core/agent/patterns/test_react.py b/api/tests/unit_tests/core/agent/patterns/test_react.py new file mode 100644 index 0000000000..a942ba6100 --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_react.py @@ -0,0 +1,224 @@ +"""Tests for ReActStrategy.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import ExecutionContext +from core.agent.patterns.react import ReActStrategy +from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +@pytest.fixture +def mock_tool(): + """Create a mock tool.""" + from core.model_runtime.entities.message_entities import PromptMessageTool + + tool = MagicMock() + tool.entity.identity.name = "test_tool" + tool.entity.identity.provider = "test_provider" + + # Use real PromptMessageTool for proper serialization + prompt_tool = PromptMessageTool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + tool.to_prompt_message_tool.return_value = prompt_tool + + return tool + + +class TestReActStrategyInit: + """Tests for ReActStrategy initialization.""" + + def test_init_with_instruction(self, mock_model_instance, mock_context): + """Test that instruction is stored correctly.""" + instruction = "You are a helpful assistant." 
+ + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + assert strategy.instruction == instruction + + def test_init_with_empty_instruction(self, mock_model_instance, mock_context): + """Test that empty instruction is handled correctly.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + assert strategy.instruction == "" + + +class TestBuildPromptWithReactFormat: + """Tests for _build_prompt_with_react_format method.""" + + def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tools}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + UserPromptMessage(content="Hello"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # The tools placeholder should be replaced with JSON + assert "{{tools}}" not in result[0].content + assert "test_tool" in result[0].content + + def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool): + """Test that {{tool_names}} placeholder is replaced.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[mock_tool], + context=mock_context, + ) + + system_content = "Valid actions: {{tool_names}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + assert "{{tool_names}}" not in result[0].content + assert '"test_tool"' in result[0].content + + def test_replace_instruction_placeholder(self, mock_model_instance, mock_context): + """Test that {{instruction}} placeholder is replaced.""" + instruction = "You are a helpful coding assistant." 
+ strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + instruction=instruction, + ) + + system_content = "{{instruction}}\n\nYou have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True, instruction) + + assert "{{instruction}}" not in result[0].content + assert instruction in result[0].content + + def test_no_tools_available_message(self, mock_model_instance, mock_context): + """Test that 'No tools available' is shown when include_tools is False.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + system_content = "You have access to: {{tools}}" + messages = [ + SystemPromptMessage(content=system_content), + ] + + result = strategy._build_prompt_with_react_format(messages, [], False) + + assert "No tools available" in result[0].content + + def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context): + """Test that agent scratchpad is appended as AssistantPromptMessage.""" + from core.agent.entities import AgentScratchpadUnit + from core.model_runtime.entities import AssistantPromptMessage + + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + scratchpad = [ + AgentScratchpadUnit( + thought="I need to search for information", + action_str='{"action": "search", "action_input": "query"}', + observation="Search results here", + ) + ] + + result = strategy._build_prompt_with_react_format(messages, scratchpad, True) + + # The last message should be an AssistantPromptMessage with scratchpad content + assert len(result) == 3 + assert isinstance(result[-1], AssistantPromptMessage) + assert "I need to search for information" in result[-1].content + assert "Search results here" in result[-1].content + + def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context): + """Test that empty scratchpad doesn't add extra message.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + messages = [ + SystemPromptMessage(content="System prompt"), + UserPromptMessage(content="User query"), + ] + + result = strategy._build_prompt_with_react_format(messages, [], True) + + # Should only have the original 2 messages + assert len(result) == 2 + + def test_original_messages_not_modified(self, mock_model_instance, mock_context): + """Test that original messages list is not modified.""" + strategy = ReActStrategy( + model_instance=mock_model_instance, + tools=[], + context=mock_context, + ) + + original_content = "Original system prompt {{tools}}" + messages = [ + SystemPromptMessage(content=original_content), + ] + + strategy._build_prompt_with_react_format(messages, [], True) + + # Original message should not be modified + assert messages[0].content == original_content diff --git a/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py new file mode 100644 index 0000000000..07b9df2acf --- /dev/null +++ b/api/tests/unit_tests/core/agent/patterns/test_strategy_factory.py @@ -0,0 +1,203 @@ +"""Tests for StrategyFactory.""" + +from unittest.mock import MagicMock + +import pytest + +from core.agent.entities import AgentEntity, ExecutionContext +from 
core.agent.patterns.function_call import FunctionCallStrategy +from core.agent.patterns.react import ReActStrategy +from core.agent.patterns.strategy_factory import StrategyFactory +from core.model_runtime.entities.model_entities import ModelFeature + + +@pytest.fixture +def mock_model_instance(): + """Create a mock model instance.""" + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + return model_instance + + +@pytest.fixture +def mock_context(): + """Create a mock execution context.""" + return ExecutionContext( + user_id="test-user", + app_id="test-app", + conversation_id="test-conversation", + message_id="test-message", + tenant_id="test-tenant", + ) + + +class TestStrategyFactory: + """Tests for StrategyFactory.create_strategy method.""" + + def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports TOOL_CALL.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL.""" + model_features = [ModelFeature.MULTI_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context): + """Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL.""" + model_features = [ModelFeature.STREAM_TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model doesn't support tool calling.""" + model_features = [ModelFeature.VISION] # Only vision, no tool calling + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context): + """Test that ReActStrategy is created when model has no features.""" + model_features: list[ModelFeature] = [] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + ) + + assert isinstance(strategy, ReActStrategy) + + def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context): + """Test explicit FUNCTION_CALLING strategy selection with model support.""" + model_features = [ModelFeature.TOOL_CALL] + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + 
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + assert isinstance(strategy, FunctionCallStrategy) + + def test_explicit_function_calling_strategy_without_support_falls_back_to_react( + self, mock_model_instance, mock_context + ): + """Test that explicit FUNCTION_CALLING falls back to ReAct when not supported.""" + model_features: list[ModelFeature] = [] # No tool calling support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, + ) + + # Should fall back to ReAct since FC is not supported + assert isinstance(strategy, ReActStrategy) + + def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context): + """Test explicit CHAIN_OF_THOUGHT strategy selection.""" + model_features = [ModelFeature.TOOL_CALL] # Even with tool call support + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + ) + + assert isinstance(strategy, ReActStrategy) + + def test_react_strategy_with_instruction(self, mock_model_instance, mock_context): + """Test that ReActStrategy receives instruction parameter.""" + model_features: list[ModelFeature] = [] + instruction = "You are a helpful assistant." + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + instruction=instruction, + ) + + assert isinstance(strategy, ReActStrategy) + assert strategy.instruction == instruction + + def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that max_iterations is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + max_iterations = 5 + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + max_iterations=max_iterations, + ) + + assert strategy.max_iterations == max_iterations + + def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context): + """Test that tool_invoke_hook is passed to the strategy.""" + model_features = [ModelFeature.TOOL_CALL] + mock_hook = MagicMock() + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=mock_model_instance, + context=mock_context, + tools=[], + files=[], + tool_invoke_hook=mock_hook, + ) + + assert strategy.tool_invoke_hook == mock_hook diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py new file mode 100644 index 0000000000..d9301ccfe0 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -0,0 +1,388 @@ +"""Tests for AgentAppRunner.""" + +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult +from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.llm_entities import LLMUsage + + +class TestOrganizePromptMessages: + """Tests for _organize_prompt_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + # We'll patch the class to avoid 
complex initialization + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + + # Set up required attributes + runner.config = MagicMock(spec=AgentEntity) + runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING + runner.config.prompt = None + + runner.app_config = MagicMock() + runner.app_config.prompt_template = MagicMock() + runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant." + + runner.history_prompt_messages = [] + runner.query = "Hello" + runner._current_thoughts = [] + runner.files = [] + runner.model_config = MagicMock() + runner.memory = None + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.file_upload_config = None + + return runner + + def test_function_calling_uses_simple_prompt(self, mock_runner): + """Test that function calling strategy uses simple_prompt_template.""" + mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="You are a helpful assistant.") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with simple_prompt_template + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "You are a helpful assistant." 
+ + def test_chain_of_thought_uses_agent_prompt(self, mock_runner): + """Test that chain of thought strategy uses agent prompt template.""" + mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + mock_runner.config.prompt = AgentPromptEntity( + first_prompt="ReAct prompt template with {{tools}}", + next_iteration="Continue...", + ) + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="ReAct prompt template with {{tools}}") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with agent prompt + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "ReAct prompt template with {{tools}}" + + def test_chain_of_thought_without_prompt_falls_back(self, mock_runner): + """Test that chain of thought without prompt falls back to simple_prompt_template.""" + mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + mock_runner.config.prompt = None + + with patch.object(mock_runner, "_init_system_message") as mock_init: + mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")] + with patch.object(mock_runner, "_organize_user_query") as mock_query: + mock_query.return_value = [UserPromptMessage(content="Hello")] + with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform: + mock_transform.return_value.get_prompt.return_value = [ + SystemPromptMessage(content="You are a helpful assistant.") + ] + + result = mock_runner._organize_prompt_messages() + + # Verify _init_system_message was called with simple_prompt_template + mock_init.assert_called_once() + call_args = mock_init.call_args[0] + assert call_args[0] == "You are a helpful assistant." 
+ + +class TestInitSystemMessage: + """Tests for _init_system_message method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_empty_messages_with_template(self, mock_runner): + """Test that system message is created when messages are empty.""" + result = mock_runner._init_system_message("System template", []) + + assert len(result) == 1 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + def test_empty_messages_without_template(self, mock_runner): + """Test that empty list is returned when no template and no messages.""" + result = mock_runner._init_system_message("", []) + + assert result == [] + + def test_existing_system_message_not_duplicated(self, mock_runner): + """Test that system message is not duplicated if already present.""" + existing_messages = [ + SystemPromptMessage(content="Existing system"), + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("New template", existing_messages) + + # Should not insert new system message + assert len(result) == 2 + assert result[0].content == "Existing system" + + def test_system_message_inserted_when_missing(self, mock_runner): + """Test that system message is inserted when first message is not system.""" + existing_messages = [ + UserPromptMessage(content="User message"), + ] + + result = mock_runner._init_system_message("System template", existing_messages) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert result[0].content == "System template" + + +class TestClearUserPromptImageMessages: + """Tests for _clear_user_prompt_image_messages method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + return runner + + def test_text_content_unchanged(self, mock_runner): + """Test that text content is unchanged.""" + messages = [ + UserPromptMessage(content="Plain text message"), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + assert len(result) == 1 + assert result[0].content == "Plain text message" + + def test_original_messages_not_modified(self, mock_runner): + """Test that original messages are not modified (deep copy).""" + from core.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Text part"), + ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ), + ] + ), + ] + + result = mock_runner._clear_user_prompt_image_messages(messages) + + # Original should still have list content + assert isinstance(messages[0].content, list) + # Result should have string content + assert isinstance(result[0].content, str) + + +class TestToolInvokeHook: + """Tests for _create_tool_invoke_hook method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + 
from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + + runner.user_id = "test-user" + runner.tenant_id = "test-tenant" + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.trace_manager = None + runner.application_generate_entity.invoke_from = "api" + runner.application_generate_entity.app_config = MagicMock() + runner.application_generate_entity.app_config.app_id = "test-app" + runner.agent_callback = MagicMock() + runner.conversation = MagicMock() + runner.conversation.id = "test-conversation" + runner.queue_manager = MagicMock() + runner._current_message_file_ids = [] + + return runner + + def test_hook_calls_agent_invoke(self, mock_runner): + """Test that the hook calls ToolEngine.agent_invoke.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={ + "tool_provider_type": "test_provider", + "tool_provider": "test_id", + }, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool") + + # Verify ToolEngine.agent_invoke was called + mock_engine.agent_invoke.assert_called_once() + + # Verify return values + assert result_content == "Tool result" + assert result_files == ["file-1", "file-2"] + assert result_meta == mock_tool_meta + + def test_hook_publishes_file_events(self, mock_runner): + """Test that the hook publishes QueueMessageFileEvent for files.""" + from core.tools.entities.tool_entities import ToolInvokeMeta + + mock_message = MagicMock() + mock_message.id = "test-message" + + mock_tool = MagicMock() + mock_tool_meta = ToolInvokeMeta( + time_cost=0.5, + error=None, + tool_config={}, + ) + + with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine: + mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta) + + hook = mock_runner._create_tool_invoke_hook(mock_message) + hook(mock_tool, {}, "test_tool") + + # Verify file events were published + assert mock_runner.queue_manager.publish.call_count == 2 + assert mock_runner._current_message_file_ids == ["file-1", "file-2"] + + +class TestAgentLogProcessing: + """Tests for AgentLog processing in run method.""" + + def test_agent_log_status_enum(self): + """Test AgentLog status enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_agent_log_metadata_enum(self): + """Test AgentLog metadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + def test_agent_result_structure(self): + """Test AgentResult structure.""" + usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.1"), + completion_tokens=50, + completion_unit_price=Decimal("0.002"), + 
completion_price_unit=Decimal("0.001"), + completion_price=Decimal("0.1"), + total_tokens=150, + total_price=Decimal("0.2"), + currency="USD", + latency=0.5, + ) + + result = AgentResult( + text="Final answer", + files=[], + usage=usage, + finish_reason="stop", + ) + + assert result.text == "Final answer" + assert result.files == [] + assert result.usage == usage + assert result.finish_reason == "stop" + + +class TestOrganizeUserQuery: + """Tests for _organize_user_query method.""" + + @pytest.fixture + def mock_runner(self): + """Create a mock AgentAppRunner for testing.""" + with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None): + from core.agent.agent_app_runner import AgentAppRunner + + runner = AgentAppRunner.__new__(AgentAppRunner) + runner.files = [] + runner.application_generate_entity = MagicMock() + runner.application_generate_entity.file_upload_config = None + return runner + + def test_simple_query_without_files(self, mock_runner): + """Test organizing a simple query without files.""" + result = mock_runner._organize_user_query("Hello world", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert result[0].content == "Hello world" + + def test_query_with_files(self, mock_runner): + """Test organizing a query with files.""" + from core.file.models import File + + mock_file = MagicMock(spec=File) + mock_runner.files = [mock_file] + + with patch("core.agent.agent_app_runner.file_manager") as mock_fm: + from core.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent( + data="http://example.com/image.jpg", + format="url", + mime_type="image/jpeg", + ) + + result = mock_runner._organize_user_query("Describe this image", []) + + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + assert isinstance(result[0].content, list) + assert len(result[0].content) == 2 # Image + Text diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py new file mode 100644 index 0000000000..5136f48aab --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_entities.py @@ -0,0 +1,191 @@ +"""Tests for agent entities.""" + +from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext + + +class TestExecutionContext: + """Tests for ExecutionContext entity.""" + + def test_create_with_all_fields(self): + """Test creating ExecutionContext with all fields.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + assert context.user_id == "user-123" + assert context.app_id == "app-456" + assert context.conversation_id == "conv-789" + assert context.message_id == "msg-012" + assert context.tenant_id == "tenant-345" + + def test_create_minimal(self): + """Test creating minimal ExecutionContext.""" + context = ExecutionContext.create_minimal(user_id="user-123") + + assert context.user_id == "user-123" + assert context.app_id is None + assert context.conversation_id is None + assert context.message_id is None + assert context.tenant_id is None + + def test_to_dict(self): + """Test converting ExecutionContext to dictionary.""" + context = ExecutionContext( + user_id="user-123", + app_id="app-456", + conversation_id="conv-789", + message_id="msg-012", + tenant_id="tenant-345", + ) + + result = context.to_dict() + + assert 
result == { + "user_id": "user-123", + "app_id": "app-456", + "conversation_id": "conv-789", + "message_id": "msg-012", + "tenant_id": "tenant-345", + } + + def test_with_updates(self): + """Test creating new context with updates.""" + original = ExecutionContext( + user_id="user-123", + app_id="app-456", + ) + + updated = original.with_updates(message_id="msg-789") + + # Original should be unchanged + assert original.message_id is None + # Updated should have new value + assert updated.message_id == "msg-789" + assert updated.user_id == "user-123" + assert updated.app_id == "app-456" + + +class TestAgentLog: + """Tests for AgentLog entity.""" + + def test_create_log_with_required_fields(self): + """Test creating AgentLog with required fields.""" + log = AgentLog( + label="ROUND 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + + assert log.label == "ROUND 1" + assert log.log_type == AgentLog.LogType.ROUND + assert log.status == AgentLog.LogStatus.START + assert log.data == {"key": "value"} + assert log.id is not None # Auto-generated + assert log.parent_id is None + assert log.error is None + + def test_log_type_enum(self): + """Test LogType enum values.""" + assert AgentLog.LogType.ROUND == "round" + assert AgentLog.LogType.THOUGHT == "thought" + assert AgentLog.LogType.TOOL_CALL == "tool_call" + + def test_log_status_enum(self): + """Test LogStatus enum values.""" + assert AgentLog.LogStatus.START == "start" + assert AgentLog.LogStatus.SUCCESS == "success" + assert AgentLog.LogStatus.ERROR == "error" + + def test_log_metadata_enum(self): + """Test LogMetadata enum values.""" + assert AgentLog.LogMetadata.STARTED_AT == "started_at" + assert AgentLog.LogMetadata.FINISHED_AT == "finished_at" + assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time" + assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price" + assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens" + assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage" + + +class TestAgentScratchpadUnit: + """Tests for AgentScratchpadUnit entity.""" + + def test_is_final_with_final_answer_action(self): + """Test is_final returns True for Final Answer action.""" + unit = AgentScratchpadUnit( + thought="I know the answer", + action=AgentScratchpadUnit.Action( + action_name="Final Answer", + action_input="The answer is 42", + ), + ) + + assert unit.is_final() is True + + def test_is_final_with_tool_action(self): + """Test is_final returns False for tool action.""" + unit = AgentScratchpadUnit( + thought="I need to search", + action=AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ), + ) + + assert unit.is_final() is False + + def test_is_final_with_no_action(self): + """Test is_final returns True when no action.""" + unit = AgentScratchpadUnit( + thought="Just thinking", + ) + + assert unit.is_final() is True + + def test_action_to_dict(self): + """Test Action.to_dict method.""" + action = AgentScratchpadUnit.Action( + action_name="search", + action_input={"query": "test"}, + ) + + result = action.to_dict() + + assert result == { + "action": "search", + "action_input": {"query": "test"}, + } + + +class TestAgentEntity: + """Tests for AgentEntity.""" + + def test_strategy_enum(self): + """Test Strategy enum values.""" + assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought" + assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling" + + def test_create_with_prompt(self): + """Test creating AgentEntity with prompt.""" + prompt 
= AgentPromptEntity( + first_prompt="You are a helpful assistant.", + next_iteration="Continue thinking...", + ) + + entity = AgentEntity( + provider="openai", + model="gpt-4", + strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT, + prompt=prompt, + max_iteration=5, + ) + + assert entity.provider == "openai" + assert entity.model == "gpt-4" + assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT + assert entity.prompt == prompt + assert entity.max_iteration == 5 diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py new file mode 100644 index 0000000000..6a8a691a25 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_stream_chunk.py @@ -0,0 +1,48 @@ +from unittest.mock import MagicMock + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.workflow.graph_events import NodeRunStreamChunkEvent +from core.workflow.nodes import NodeType + + +class DummyQueueManager: + def __init__(self) -> None: + self.published = [] + + def publish(self, event, publish_from: PublishFrom) -> None: + self.published.append((event, publish_from)) + + +def test_skip_empty_final_chunk() -> None: + queue_manager = DummyQueueManager() + runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app") + + empty_final_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="", + is_final=True, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event) + assert queue_manager.published == [] + + normal_event = NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=NodeType.LLM, + selector=["node", "text"], + chunk="hi", + is_final=False, + ) + + runner._handle_event(workflow_entry=MagicMock(), event=normal_event) + + assert len(queue_manager.published) == 1 + published_event, publish_from = queue_manager.published[0] + assert publish_from == PublishFrom.APPLICATION_MANAGER + assert published_event.text == "hi" + diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py new file mode 100644 index 0000000000..822b6a808f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -0,0 +1,231 @@ +"""Tests for ResponseStreamCoordinator object field streaming.""" + +from unittest.mock import MagicMock + +from core.workflow.entities.tool_entities import ToolResultStatus +from core.workflow.enums import NodeType +from core.workflow.graph.graph import Graph +from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from core.workflow.graph_engine.response_coordinator.session import ResponseSession +from core.workflow.graph_events import ( + ChunkType, + NodeRunStreamChunkEvent, + ToolCall, + ToolResult, +) +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.template import Template, VariableSegment +from core.workflow.runtime import VariablePool + + +class TestResponseCoordinatorObjectStreaming: + """Test streaming of object-type variables with child fields.""" + + def test_object_field_streaming(self): + """Test that when selecting an object variable, all child field streams are forwarded.""" + # Create mock graph and variable pool + graph = 
MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + # Mock nodes + llm_node = MagicMock() + llm_node.id = "llm_node" + llm_node.node_type = NodeType.LLM + llm_node.execution_type = MagicMock() + llm_node.blocks_variable_output = MagicMock(return_value=False) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + response_node.execution_type = MagicMock() + response_node.blocks_variable_output = MagicMock(return_value=False) + + # Mock template for response node + response_node.node_data = MagicMock(spec=BaseNodeData) + response_node.node_data.answer = "{{#llm_node.generation#}}" + + graph.nodes = { + "llm_node": llm_node, + "response_node": response_node, + } + graph.root_node = llm_node + graph.get_outgoing_edges = MagicMock(return_value=[]) + + # Create coordinator + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Track execution + coordinator.track_node_execution("llm_node", "exec_123") + coordinator.track_node_execution("response_node", "exec_456") + + # Simulate streaming events for child fields of generation object + # 1. Content stream + content_event_1 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk="Hello", + is_final=False, + chunk_type=ChunkType.TEXT, + ) + content_event_2 = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "content"], + chunk=" world", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + # 2. Tool call stream + tool_call_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_calls"], + chunk='{"query": "test"}', + is_final=True, + chunk_type=ChunkType.TOOL_CALL, + tool_call=ToolCall( + id="call_123", + name="search", + arguments='{"query": "test"}', + ), + ) + + # 3. 
Tool result stream + tool_result_event = NodeRunStreamChunkEvent( + id="exec_123", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["llm_node", "generation", "tool_results"], + chunk="Found 10 results", + is_final=True, + chunk_type=ChunkType.TOOL_RESULT, + tool_result=ToolResult( + id="call_123", + name="search", + output="Found 10 results", + files=[], + status=ToolResultStatus.SUCCESS, + ), + ) + + # Intercept these events + coordinator.intercept_event(content_event_1) + coordinator.intercept_event(tool_call_event) + coordinator.intercept_event(tool_result_event) + coordinator.intercept_event(content_event_2) + + # Verify that all child streams are buffered + assert ("llm_node", "generation", "content") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers + assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers + + # Verify payloads are preserved in buffered events + buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0] + assert buffered_call.tool_call is not None + assert buffered_call.tool_call.id == "call_123" + buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0] + assert buffered_result.tool_result is not None + assert buffered_result.tool_result.status == "success" + + # Verify we can find child streams + child_streams = coordinator._find_child_streams(["llm_node", "generation"]) + assert len(child_streams) == 3 + assert ("llm_node", "generation", "content") in child_streams + assert ("llm_node", "generation", "tool_calls") in child_streams + assert ("llm_node", "generation", "tool_results") in child_streams + + def test_find_child_streams(self): + """Test the _find_child_streams method.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some mock streams + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + ("node1", "generation", "tool_calls"): [], + ("node1", "generation", "thought"): [], + ("node1", "text"): [], # Not a child of generation + ("node2", "generation", "content"): [], # Different node + } + + # Find children of node1.generation + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children + assert ("node1", "text") not in children + assert ("node2", "generation", "content") not in children + + def test_find_child_streams_with_closed_streams(self): + """Test that _find_child_streams also considers closed streams.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + + # Add some streams - some buffered, some closed + coordinator._stream_buffers = { + ("node1", "generation", "content"): [], + } + coordinator._closed_streams = { + ("node1", "generation", "tool_calls"), + ("node1", "generation", "thought"), + } + + # Should find all children regardless of whether they're in buffers or closed + children = coordinator._find_child_streams(["node1", "generation"]) + + assert len(children) == 3 + assert ("node1", "generation", "content") in children + assert ("node1", "generation", "tool_calls") in children + assert ("node1", "generation", "thought") in children + + def 
test_special_selector_rewrites_to_active_response_node(self): + """Ensure special selectors attribute streams to the active response node.""" + graph = MagicMock(spec=Graph) + variable_pool = MagicMock(spec=VariablePool) + + response_node = MagicMock() + response_node.id = "response_node" + response_node.node_type = NodeType.ANSWER + graph.nodes = {"response_node": response_node} + graph.root_node = response_node + + coordinator = ResponseStreamCoordinator(variable_pool, graph) + coordinator.track_node_execution("response_node", "exec_resp") + + coordinator._active_session = ResponseSession( + node_id="response_node", + template=Template(segments=[VariableSegment(selector=["sys", "foo"])]), + ) + + event = NodeRunStreamChunkEvent( + id="stream_1", + node_id="llm_node", + node_type=NodeType.LLM, + selector=["sys", "foo"], + chunk="hi", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + coordinator._stream_buffers[("sys", "foo")] = [event] + coordinator._stream_positions[("sys", "foo")] = 0 + coordinator._closed_streams.add(("sys", "foo")) + + events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"])) + + assert is_complete + assert len(events) == 1 + rewritten = events[0] + assert rewritten.node_id == "response_node" + assert rewritten.id == "exec_resp" diff --git a/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py new file mode 100644 index 0000000000..951149e933 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py @@ -0,0 +1,328 @@ +"""Tests for StreamChunkEvent and its subclasses.""" + +from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus +from core.workflow.node_events import ( + ChunkType, + StreamChunkEvent, + ThoughtChunkEvent, + ToolCallChunkEvent, + ToolResultChunkEvent, +) + + +class TestChunkType: + """Tests for ChunkType enum.""" + + def test_chunk_type_values(self): + """Test that ChunkType has expected values.""" + assert ChunkType.TEXT == "text" + assert ChunkType.TOOL_CALL == "tool_call" + assert ChunkType.TOOL_RESULT == "tool_result" + assert ChunkType.THOUGHT == "thought" + + def test_chunk_type_is_str_enum(self): + """Test that ChunkType values are strings.""" + for chunk_type in ChunkType: + assert isinstance(chunk_type.value, str) + + +class TestStreamChunkEvent: + """Tests for base StreamChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating StreamChunkEvent with required fields.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="Hello", + ) + + assert event.selector == ["node1", "text"] + assert event.chunk == "Hello" + assert event.is_final is False + assert event.chunk_type == ChunkType.TEXT + + def test_create_with_all_fields(self): + """Test creating StreamChunkEvent with all fields.""" + event = StreamChunkEvent( + selector=["node1", "output"], + chunk="World", + is_final=True, + chunk_type=ChunkType.TEXT, + ) + + assert event.selector == ["node1", "output"] + assert event.chunk == "World" + assert event.is_final is True + assert event.chunk_type == ChunkType.TEXT + + def test_default_chunk_type_is_text(self): + """Test that default chunk_type is TEXT.""" + event = StreamChunkEvent( + selector=["node1", "text"], + chunk="test", + ) + + assert event.chunk_type == ChunkType.TEXT + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = StreamChunkEvent( + selector=["node1", 
"text"], + chunk="Hello", + is_final=True, + ) + + data = event.model_dump() + + assert data["selector"] == ["node1", "text"] + assert data["chunk"] == "Hello" + assert data["is_final"] is True + assert data["chunk_type"] == "text" + + +class TestToolCallChunkEvent: + """Tests for ToolCallChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolCallChunkEvent with required fields.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call=ToolCall(id="call_123", name="weather", arguments=None), + ) + + assert event.selector == ["node1", "tool_calls"] + assert event.chunk == '{"city": "Beijing"}' + assert event.tool_call.id == "call_123" + assert event.tool_call.name == "weather" + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_chunk_type_is_tool_call(self): + """Test that chunk_type is always TOOL_CALL.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk="", + tool_call=ToolCall(id="call_123", name="test_tool", arguments=None), + ) + + assert event.chunk_type == ChunkType.TOOL_CALL + + def test_tool_arguments_field(self): + """Test tool_arguments field.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"param": "value"}', + tool_call=ToolCall( + id="call_123", + name="test_tool", + arguments='{"param": "value"}', + ), + ) + + assert event.tool_call.arguments == '{"param": "value"}' + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ToolCallChunkEvent( + selector=["node1", "tool_calls"], + chunk='{"city": "Beijing"}', + tool_call=ToolCall( + id="call_123", + name="weather", + arguments='{"city": "Beijing"}', + ), + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_call" + assert data["tool_call"]["id"] == "call_123" + assert data["tool_call"]["name"] == "weather" + assert data["tool_call"]["arguments"] == '{"city": "Beijing"}' + assert data["is_final"] is True + + +class TestToolResultChunkEvent: + """Tests for ToolResultChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ToolResultChunkEvent with required fields.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="Weather: Sunny, 25°C", + tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"), + ) + + assert event.selector == ["node1", "tool_results"] + assert event.chunk == "Weather: Sunny, 25°C" + assert event.tool_result.id == "call_123" + assert event.tool_result.name == "weather" + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_chunk_type_is_tool_result(self): + """Test that chunk_type is always TOOL_RESULT.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult(id="call_123", name="test_tool"), + ) + + assert event.chunk_type == ChunkType.TOOL_RESULT + + def test_tool_files_default_empty(self): + """Test that tool_files defaults to empty list.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult(id="call_123", name="test_tool"), + ) + + assert event.tool_result.files == [] + + def test_tool_files_with_values(self): + """Test tool_files with file IDs.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="result", + tool_result=ToolResult( + id="call_123", + name="test_tool", + files=["file_1", "file_2"], + ), + ) + + assert event.tool_result.files == 
["file_1", "file_2"] + + def test_tool_error_output(self): + """Test error output captured in tool_result.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="", + tool_result=ToolResult( + id="call_123", + name="test_tool", + output="Tool execution failed", + status=ToolResultStatus.ERROR, + ), + ) + + assert event.tool_result.output == "Tool execution failed" + assert event.tool_result.status == ToolResultStatus.ERROR + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ToolResultChunkEvent( + selector=["node1", "tool_results"], + chunk="Weather: Sunny", + tool_result=ToolResult( + id="call_123", + name="weather", + output="Weather: Sunny", + files=["file_1"], + status=ToolResultStatus.SUCCESS, + ), + is_final=True, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "tool_result" + assert data["tool_result"]["id"] == "call_123" + assert data["tool_result"]["name"] == "weather" + assert data["tool_result"]["files"] == ["file_1"] + assert data["is_final"] is True + + +class TestThoughtChunkEvent: + """Tests for ThoughtChunkEvent.""" + + def test_create_with_required_fields(self): + """Test creating ThoughtChunkEvent with required fields.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to query the weather...", + ) + + assert event.selector == ["node1", "thought"] + assert event.chunk == "I need to query the weather..." + assert event.chunk_type == ChunkType.THOUGHT + + def test_chunk_type_is_thought(self): + """Test that chunk_type is always THOUGHT.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="thinking...", + ) + + assert event.chunk_type == ChunkType.THOUGHT + + def test_serialization(self): + """Test that event can be serialized to dict.""" + event = ThoughtChunkEvent( + selector=["node1", "thought"], + chunk="I need to analyze this...", + is_final=False, + ) + + data = event.model_dump() + + assert data["chunk_type"] == "thought" + assert data["chunk"] == "I need to analyze this..." 
+        assert data["is_final"] is False
+
+
+class TestEventInheritance:
+    """Tests for event inheritance relationships."""
+
+    def test_tool_call_is_stream_chunk(self):
+        """Test that ToolCallChunkEvent is a StreamChunkEvent."""
+        event = ToolCallChunkEvent(
+            selector=["node1", "tool_calls"],
+            chunk="",
+            tool_call=ToolCall(id="call_123", name="test", arguments=None),
+        )
+
+        assert isinstance(event, StreamChunkEvent)
+
+    def test_tool_result_is_stream_chunk(self):
+        """Test that ToolResultChunkEvent is a StreamChunkEvent."""
+        event = ToolResultChunkEvent(
+            selector=["node1", "tool_results"],
+            chunk="result",
+            tool_result=ToolResult(id="call_123", name="test"),
+        )
+
+        assert isinstance(event, StreamChunkEvent)
+
+    def test_thought_is_stream_chunk(self):
+        """Test that ThoughtChunkEvent is a StreamChunkEvent."""
+        event = ThoughtChunkEvent(
+            selector=["node1", "thought"],
+            chunk="thinking...",
+        )
+
+        assert isinstance(event, StreamChunkEvent)
+
+    def test_all_events_have_common_fields(self):
+        """Test that all events have common StreamChunkEvent fields."""
+        events = [
+            StreamChunkEvent(selector=["n", "t"], chunk="a"),
+            ToolCallChunkEvent(
+                selector=["n", "t"],
+                chunk="b",
+                tool_call=ToolCall(id="1", name="t", arguments=None),
+            ),
+            ToolResultChunkEvent(
+                selector=["n", "t"],
+                chunk="c",
+                tool_result=ToolResult(id="1", name="t"),
+            ),
+            ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
+        ]
+
+        for event in events:
+            assert hasattr(event, "selector")
+            assert hasattr(event, "chunk")
+            assert hasattr(event, "is_final")
+            assert hasattr(event, "chunk_type")
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py
new file mode 100644
index 0000000000..55f6525bcc
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/test_llm_node_streaming.py
@@ -0,0 +1,149 @@
+import types
+from collections.abc import Generator
+from typing import Any
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities import ToolCallResult
+from core.workflow.entities.tool_entities import ToolResultStatus
+from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
+from core.workflow.nodes.llm.node import LLMNode
+
+
+class _StubModelInstance:
+    """Minimal stub to satisfy _stream_llm_events signature."""
+
+    provider_model_bundle = None
+
+
+def _drain(generator: Generator[NodeEventBase, None, Any]):
+    events: list = []
+    try:
+        while True:
+            events.append(next(generator))
+    except StopIteration as exc:
+        return events, exc.value
+
+
+@pytest.fixture(autouse=True)
+def patch_deduct_llm_quota(monkeypatch):
+    # Avoid touching real quota logic during unit tests
+    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
+
+
+def _make_llm_node(reasoning_format: str) -> LLMNode:
+    node = LLMNode.__new__(LLMNode)
+    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
+    object.__setattr__(node, "tenant_id", "tenant")
+    return node
+
+
+def test_stream_llm_events_extracts_reasoning_for_tagged():
+    node = _make_llm_node(reasoning_format="tagged")
+    tagged_text = "<think>Thought</think>Answer"
+    usage = LLMUsage.empty_usage()
+
+    def generator():
+        yield ModelInvokeCompletedEvent(
+            text=tagged_text,
+            usage=usage,
+            finish_reason="stop",
+            reasoning_content="",
+            structured_output=None,
+        )
+
+    events, returned = _drain(
+        node._stream_llm_events(generator(), 
model_instance=types.SimpleNamespace(provider_model_bundle=None)) + ) + + assert events == [] + clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned + assert clean_text == tagged_text # original preserved for output + assert reasoning_content == "" # tagged mode keeps reasoning separate + assert gen_clean == "Answer" # stripped content for generation + assert gen_reasoning == "Thought" # reasoning extracted from tag + assert ret_usage == usage + assert finish_reason == "stop" + assert structured is None + assert gen_data is None + + # generation building should include reasoning and sequence + generation_content = gen_clean or clean_text + sequence = [ + {"type": "reasoning", "index": 0}, + {"type": "content", "start": 0, "end": len(generation_content)}, + ] + assert sequence == [ + {"type": "reasoning", "index": 0}, + {"type": "content", "start": 0, "end": len("Answer")}, + ] + + +def test_stream_llm_events_no_reasoning_results_in_empty_sequence(): + node = _make_llm_node(reasoning_format="tagged") + plain_text = "Hello world" + usage = LLMUsage.empty_usage() + + def generator(): + yield ModelInvokeCompletedEvent( + text=plain_text, + usage=usage, + finish_reason=None, + reasoning_content="", + structured_output=None, + ) + + events, returned = _drain( + node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None)) + ) + + assert events == [] + _, _, gen_reasoning, gen_clean, *_ = returned + generation_content = gen_clean or plain_text + assert gen_reasoning == "" + assert generation_content == plain_text + # Empty reasoning should imply empty sequence in generation construction + sequence = [] + assert sequence == [] + + +def test_serialize_tool_call_strips_files_to_ids(): + file_cls = pytest.importorskip("core.file").File + file_type = pytest.importorskip("core.file.enums").FileType + transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod + + file_with_id = file_cls( + id="f1", + tenant_id="t", + type=file_type.IMAGE, + transfer_method=transfer_method.REMOTE_URL, + remote_url="http://example.com/f1", + storage_key="k1", + ) + file_with_related = file_cls( + id=None, + tenant_id="t", + type=file_type.IMAGE, + transfer_method=transfer_method.REMOTE_URL, + related_id="rel2", + remote_url="http://example.com/f2", + storage_key="k2", + ) + tool_call = ToolCallResult( + id="tc", + name="do", + arguments='{"a":1}', + output="ok", + files=[file_with_id, file_with_related], + status=ToolResultStatus.SUCCESS, + ) + + serialized = LLMNode._serialize_tool_call(tool_call) + + assert serialized["files"] == ["f1", "rel2"] + assert serialized["id"] == "tc" + assert serialized["name"] == "do" + assert serialized["arguments"] == '{"a":1}' + assert serialized["output"] == "ok" +