From 96641a93f6d206bb7ad55c8b370f16bae538aa8a Mon Sep 17 00:00:00 2001 From: Yansong Zhang <916125788@qq.com> Date: Wed, 8 Apr 2026 12:31:23 +0800 Subject: [PATCH] feat(api): add Agent V2 node and new Agent app type (Phase 1-3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a new unified Agent V2 workflow node that combines LLM capabilities with agent tool-calling loops, along with a new AppMode.AGENT for standalone agent apps backed by single-node workflows. Phase 1 — Agent Patterns: - Add core/agent/patterns/ module (AgentPattern, FunctionCallStrategy, ReActStrategy, StrategyFactory) ported from feat/support-agent-sandbox - Add ExecutionContext, AgentLog, AgentResult entities - Add Tool.to_prompt_message_tool() for LLM-consumable tool conversion Phase 2 — Agent V2 Workflow Node: - Add core/workflow/nodes/agent_v2/ (AgentV2Node, AgentV2NodeData, AgentV2ToolManager, AgentV2EventAdapter) - Register agent-v2 node type in DifyNodeFactory - No-tools path: single LLM call (LLM Node equivalent) - Tools path: FC/ReAct loop via StrategyFactory Phase 3 — Agent App Type: - Add AppMode.AGENT to model enum - Add WorkflowGraphFactory for auto-generating start->agent_v2->answer graphs - AppService.create_app() creates workflow draft for AGENT mode - AppGenerateService.generate() routes AGENT to AdvancedChatAppGenerator - Console API and DSL import/export support AGENT mode - Default app template for AGENT mode Old agent/agent-chat/LLM node paths are fully preserved. 38 unit tests all passing. Made-with: Cursor --- api/constants/model_template.py | 16 + api/controllers/console/app/app.py | 8 +- api/core/agent/entities.py | 78 +++ api/core/agent/patterns/__init__.py | 19 + api/core/agent/patterns/base.py | 506 ++++++++++++++++++ api/core/agent/patterns/function_call.py | 359 +++++++++++++ api/core/agent/patterns/react.py | 419 +++++++++++++++ api/core/agent/patterns/strategy_factory.py | 108 ++++ api/core/tools/__base/tool.py | 57 ++ api/core/workflow/node_factory.py | 10 + api/core/workflow/nodes/agent_v2/__init__.py | 4 + api/core/workflow/nodes/agent_v2/entities.py | 86 +++ .../workflow/nodes/agent_v2/event_adapter.py | 96 ++++ api/core/workflow/nodes/agent_v2/node.py | 370 +++++++++++++ .../workflow/nodes/agent_v2/tool_manager.py | 122 +++++ api/models/model.py | 1 + api/services/app_dsl_service.py | 4 +- api/services/app_generate_service.py | 48 ++ api/services/app_service.py | 32 ++ api/services/workflow/graph_factory.py | 113 ++++ .../core/workflow/nodes/agent_v2/__init__.py | 0 .../nodes/agent_v2/test_agent_v2_basic.py | 332 ++++++++++++ .../nodes/agent_v2/test_agent_v2_phase3.py | 132 +++++ 23 files changed, 2915 insertions(+), 5 deletions(-) create mode 100644 api/core/agent/patterns/__init__.py create mode 100644 api/core/agent/patterns/base.py create mode 100644 api/core/agent/patterns/function_call.py create mode 100644 api/core/agent/patterns/react.py create mode 100644 api/core/agent/patterns/strategy_factory.py create mode 100644 api/core/workflow/nodes/agent_v2/__init__.py create mode 100644 api/core/workflow/nodes/agent_v2/entities.py create mode 100644 api/core/workflow/nodes/agent_v2/event_adapter.py create mode 100644 api/core/workflow/nodes/agent_v2/node.py create mode 100644 api/core/workflow/nodes/agent_v2/tool_manager.py create mode 100644 api/services/workflow/graph_factory.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent_v2/__init__.py create mode 100644 
api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_basic.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_phase3.py diff --git a/api/constants/model_template.py b/api/constants/model_template.py index cacf6b6874..d59405e548 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -81,4 +81,20 @@ default_app_templates: Mapping[AppMode, Mapping] = { }, }, }, + # agent default mode (new agent backed by single-node workflow) + AppMode.AGENT: { + "app": { + "mode": AppMode.AGENT, + "enable_site": True, + "enable_api": True, + }, + "model_config": { + "model": { + "provider": "openai", + "name": "gpt-4o", + "mode": "chat", + "completion_params": {}, + }, + }, + }, } diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c4b9bf6540..8bc28aba5a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.feature_service import FeatureService -ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] register_enum_models(console_ns, IconType) @@ -61,7 +61,7 @@ _logger = logging.getLogger(__name__) class AppListQuery(BaseModel): page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)") - mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field( + mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field( default="all", description="App mode filter" ) name: str | None = Field(default=None, description="Filter by app name") @@ -93,7 +93,9 @@ class AppListQuery(BaseModel): class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") + mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field( + ..., description="App mode" + ) icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 220feced1d..6ec76a9f99 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,3 +1,5 @@ +import uuid +from collections.abc import Mapping from enum import StrEnum from typing import Any, Union @@ -92,3 +94,79 @@ class AgentInvokeMessage(ToolInvokeMessage): """ pass + + +class ExecutionContext(BaseModel): + """Execution context containing trace and audit information. + + Carries IDs and metadata needed for tracing, auditing, and correlation + but not part of the core business logic. 
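+
+    A small usage sketch (the ids are placeholder values):
+
+        context = ExecutionContext.create_minimal(user_id="user-1")
+        context = context.with_updates(app_id="app-1", tenant_id="tenant-1")
+        assert context.to_dict()["app_id"] == "app-1"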
+ """ + + user_id: str | None = None + app_id: str | None = None + conversation_id: str | None = None + message_id: str | None = None + tenant_id: str | None = None + + @classmethod + def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext": + return cls(user_id=user_id) + + def to_dict(self) -> dict[str, Any]: + return { + "user_id": self.user_id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "message_id": self.message_id, + "tenant_id": self.tenant_id, + } + + def with_updates(self, **kwargs) -> "ExecutionContext": + data = self.to_dict() + data.update(kwargs) + return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields}) + + +class AgentLog(BaseModel): + """Structured log entry for agent execution tracing.""" + + class LogType(StrEnum): + ROUND = "round" + THOUGHT = "thought" + TOOL_CALL = "tool_call" + + class LogMetadata(StrEnum): + STARTED_AT = "started_at" + FINISHED_AT = "finished_at" + ELAPSED_TIME = "elapsed_time" + TOTAL_PRICE = "total_price" + TOTAL_TOKENS = "total_tokens" + PROVIDER = "provider" + CURRENCY = "currency" + LLM_USAGE = "llm_usage" + ICON = "icon" + ICON_DARK = "icon_dark" + + class LogStatus(StrEnum): + START = "start" + ERROR = "error" + SUCCESS = "success" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + label: str = Field(...) + log_type: LogType = Field(...) + parent_id: str | None = Field(default=None) + error: str | None = Field(default=None) + status: LogStatus = Field(...) + data: Mapping[str, Any] = Field(...) + metadata: Mapping[LogMetadata, Any] = Field(default={}) + + +class AgentResult(BaseModel): + """Agent execution result.""" + + text: str = Field(default="") + files: list[Any] = Field(default_factory=list) + usage: Any | None = Field(default=None) + finish_reason: str | None = Field(default=None) diff --git a/api/core/agent/patterns/__init__.py b/api/core/agent/patterns/__init__.py new file mode 100644 index 0000000000..8a3b125533 --- /dev/null +++ b/api/core/agent/patterns/__init__.py @@ -0,0 +1,19 @@ +"""Agent patterns module. 
+ +This module provides different strategies for agent execution: +- FunctionCallStrategy: Uses native function/tool calling +- ReActStrategy: Uses ReAct (Reasoning + Acting) approach +- StrategyFactory: Factory for creating strategies based on model features +""" + +from .base import AgentPattern +from .function_call import FunctionCallStrategy +from .react import ReActStrategy +from .strategy_factory import StrategyFactory + +__all__ = [ + "AgentPattern", + "FunctionCallStrategy", + "ReActStrategy", + "StrategyFactory", +] diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py new file mode 100644 index 0000000000..7d182dbc84 --- /dev/null +++ b/api/core/agent/patterns/base.py @@ -0,0 +1,506 @@ +"""Base class for agent strategies.""" + +from __future__ import annotations + +import json +import re +import time +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING, Any + +from core.agent.entities import AgentLog, AgentResult, ExecutionContext +from core.model_manager import ModelInstance +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta +from graphon.file import File +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + PromptMessageTool, +) +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + +# Type alias for tool invoke hook +# Returns: (response_content, message_file_ids, tool_invoke_meta) +ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]] + + +class AgentPattern(ABC): + """Base class for agent execution strategies.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + ): + """Initialize the agent strategy.""" + self.model_instance = model_instance + self.tools = tools + self.context = context + self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations + self.workflow_call_depth = workflow_call_depth + self.files: list[File] = files + self.tool_invoke_hook = tool_invoke_hook + + @abstractmethod + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the agent strategy.""" + pass + + def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None: + """Accumulate LLM usage statistics.""" + if not total_usage.get("usage"): + # Create a copy to avoid modifying the original + total_usage["usage"] = LLMUsage( + prompt_tokens=delta_usage.prompt_tokens, + prompt_unit_price=delta_usage.prompt_unit_price, + prompt_price_unit=delta_usage.prompt_price_unit, + prompt_price=delta_usage.prompt_price, + completion_tokens=delta_usage.completion_tokens, + completion_unit_price=delta_usage.completion_unit_price, + completion_price_unit=delta_usage.completion_price_unit, + completion_price=delta_usage.completion_price, + total_tokens=delta_usage.total_tokens, + total_price=delta_usage.total_price, + currency=delta_usage.currency, + latency=delta_usage.latency, + ) + else: + current: LLMUsage = 
total_usage["usage"] + current.prompt_tokens += delta_usage.prompt_tokens + current.completion_tokens += delta_usage.completion_tokens + current.total_tokens += delta_usage.total_tokens + current.prompt_price += delta_usage.prompt_price + current.completion_price += delta_usage.completion_price + current.total_price += delta_usage.total_price + + def _extract_content(self, content: Any) -> str: + """Extract text content from message content.""" + if isinstance(content, list): + # Content items are PromptMessageContentUnionTypes + text_parts = [] + for c in content: + # Check if it's a TextPromptMessageContent (which has data attribute) + if isinstance(c, TextPromptMessageContent): + text_parts.append(c.data) + return "".join(text_parts) + return str(content) + + def _has_tool_calls(self, chunk: LLMResultChunk) -> bool: + """Check if chunk contains tool calls.""" + # LLMResultChunk always has delta attribute + return bool(chunk.delta.message and chunk.delta.message.tool_calls) + + def _has_tool_calls_result(self, result: LLMResult) -> bool: + """Check if result contains tool calls (non-streaming).""" + # LLMResult always has message attribute + return bool(result.message and result.message.tool_calls) + + def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from streaming chunk.""" + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + if chunk.delta.message and chunk.delta.message.tool_calls: + for tool_call in chunk.delta.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]: + """Extract tool calls from non-streaming result.""" + tool_calls = [] + if result.message and result.message.tool_calls: + for tool_call in result.message.tool_calls: + if tool_call.function: + try: + args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + except json.JSONDecodeError: + args = {} + tool_calls.append((tool_call.id or "", tool_call.function.name, args)) + return tool_calls + + def _extract_text_from_message(self, message: PromptMessage) -> str: + """Extract text content from a prompt message.""" + # PromptMessage always has content attribute + content = message.content + if isinstance(content, str): + return content + elif isinstance(content, list): + # Extract text from content list + text_parts = [] + for item in content: + if isinstance(item, TextPromptMessageContent): + text_parts.append(item.data) + return " ".join(text_parts) + return "" + + def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]: + """Get metadata for a tool including provider and icon info.""" + from core.tools.tool_manager import ToolManager + + metadata: dict[AgentLog.LogMetadata, Any] = {} + if tool_instance.entity and tool_instance.entity.identity: + identity = tool_instance.entity.identity + if identity.provider: + metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider + + # Get icon using ToolManager for proper URL generation + tenant_id = self.context.tenant_id + if tenant_id and identity.provider: + try: + provider_type = tool_instance.tool_provider_type() + icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider) + if isinstance(icon, str): + 
metadata[AgentLog.LogMetadata.ICON] = icon + elif isinstance(icon, dict): + # Handle icon dict with background/content or light/dark variants + metadata[AgentLog.LogMetadata.ICON] = icon + except Exception: + # Fallback to identity.icon if ToolManager fails + if identity.icon: + metadata[AgentLog.LogMetadata.ICON] = identity.icon + elif identity.icon: + metadata[AgentLog.LogMetadata.ICON] = identity.icon + return metadata + + def _create_log( + self, + label: str, + log_type: AgentLog.LogType, + status: AgentLog.LogStatus, + data: dict[str, Any] | None = None, + parent_id: str | None = None, + extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None, + ) -> AgentLog: + """Create a new AgentLog with standard metadata.""" + metadata: dict[AgentLog.LogMetadata, Any] = { + AgentLog.LogMetadata.STARTED_AT: time.perf_counter(), + } + if extra_metadata: + metadata.update(extra_metadata) + + return AgentLog( + label=label, + log_type=log_type, + status=status, + data=data or {}, + parent_id=parent_id, + metadata=metadata, + ) + + def _finish_log( + self, + log: AgentLog, + data: dict[str, Any] | None = None, + usage: LLMUsage | None = None, + ) -> AgentLog: + """Finish an AgentLog by updating its status and metadata.""" + log.status = AgentLog.LogStatus.SUCCESS + + if data is not None: + log.data = data + + # Calculate elapsed time + started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter()) + finished_at = time.perf_counter() + + # Update metadata + log.metadata = { + **log.metadata, + AgentLog.LogMetadata.FINISHED_AT: finished_at, + # Calculate elapsed time in seconds + AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4), + } + + # Add usage information if provided + if usage: + log.metadata.update( + { + AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price, + AgentLog.LogMetadata.CURRENCY: usage.currency, + AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens, + AgentLog.LogMetadata.LLM_USAGE: usage, + } + ) + + return log + + def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]: + """ + Replace file references in tool arguments with actual File objects. + + Args: + tool_args: Dictionary of tool arguments + + Returns: + Updated tool arguments with file references replaced + """ + # Process each argument in the dictionary + processed_args: dict[str, Any] = {} + for key, value in tool_args.items(): + processed_args[key] = self._process_file_reference(value) + return processed_args + + def _process_file_reference(self, data: Any) -> Any: + """ + Recursively process data to replace file references. + Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...]. 
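+
+        Example (file ids are placeholders)::
+
+            strategy._process_file_reference("[File: f1]")
+            # -> the matching File from self.files, or the original string if absent
+            strategy._process_file_reference({"docs": "[Files: f1, f2]"})
+            # -> {"docs": [file_f1, file_f2]} when both ids match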
+ + Args: + data: The data to process (can be dict, list, str, or other types) + + Returns: + Processed data with file references replaced + """ + single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$") + multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$") + + if isinstance(data, dict): + # Process dictionary recursively + return {key: self._process_file_reference(value) for key, value in data.items()} + elif isinstance(data, list): + # Process list recursively + return [self._process_file_reference(item) for item in data] + elif isinstance(data, str): + # Check for single file pattern [File: file_id] + single_match = single_file_pattern.match(data.strip()) + if single_match: + file_id = single_match.group(1).strip() + # Find the file in self.files + for file in self.files: + if file.id and str(file.id) == file_id: + return file + # If file not found, return original value + return data + + # Check for multiple files pattern [Files: file_id1, file_id2, ...] + multiple_match = multiple_files_pattern.match(data.strip()) + if multiple_match: + file_ids_str = multiple_match.group(1).strip() + # Split by comma and strip whitespace + file_ids = [fid.strip() for fid in file_ids_str.split(",")] + + # Find all matching files + matched_files: list[File] = [] + for file_id in file_ids: + for file in self.files: + if file.id and str(file.id) == file_id: + matched_files.append(file) + break + + # Return list of files if any were found, otherwise return original + return matched_files or data + + return data + else: + # Return other types as-is + return data + + def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk: + """Create a text chunk for streaming.""" + return LLMResultChunk( + model=self.model_instance.model_name, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + usage=None, + ), + system_fingerprint="", + ) + + def _invoke_tool( + self, + tool_instance: Tool, + tool_args: dict[str, Any], + tool_name: str, + ) -> tuple[str, list[File], ToolInvokeMeta | None]: + """ + Invoke a tool and collect its response. 
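+
+        When ``tool_invoke_hook`` is set it short-circuits the generic path.
+        A conforming hook sketch (assuming ``ToolInvokeMeta.empty()`` is available)::
+
+            def audit_hook(tool, args, name):
+                # must return (response_content, message_file_ids, meta)
+                return f"{name} ok", [], ToolInvokeMeta.empty()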
+ + Args: + tool_instance: The tool instance to invoke + tool_args: Tool arguments + tool_name: Name of the tool + + Returns: + Tuple of (response_content, tool_files, tool_invoke_meta) + """ + # Process tool_args to replace file references with actual File objects + tool_args = self._replace_file_references(tool_args) + + # If a tool invoke hook is set, use it instead of generic_invoke + if self.tool_invoke_hook: + response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name) + # Note: message_file_ids are stored in DB, we don't convert them to File objects here + # The caller (AgentAppRunner) handles file publishing + return response_content, [], tool_invoke_meta + + # Default: use generic_invoke for workflow scenarios + # Import here to avoid circular import + from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine + + tool_response = ToolEngine.generic_invoke( + tool=tool_instance, + tool_parameters=tool_args, + user_id=self.context.user_id or "", + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + app_id=self.context.app_id, + conversation_id=self.context.conversation_id, + message_id=self.context.message_id, + ) + + # Collect response and files + response_content = "" + tool_files: list[File] = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + response_content += response.message.text + + elif response.type == ToolInvokeMessage.MessageType.LINK: + # Handle link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Link: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE: + # Handle image URL messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK: + # Handle image link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Image: {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK: + # Handle binary file link messages + if isinstance(response.message, ToolInvokeMessage.TextMessage): + filename = response.meta.get("filename", "file") if response.meta else "file" + response_content += f"[File: {filename} - {response.message.text}]" + + elif response.type == ToolInvokeMessage.MessageType.JSON: + # Handle JSON messages + if isinstance(response.message, ToolInvokeMessage.JsonMessage): + response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2) + + elif response.type == ToolInvokeMessage.MessageType.BLOB: + # Handle blob messages - convert to text representation + if isinstance(response.message, ToolInvokeMessage.BlobMessage): + mime_type = ( + response.meta.get("mime_type", "application/octet-stream") + if response.meta + else "application/octet-stream" + ) + size = len(response.message.blob) + response_content += f"[Binary data: {mime_type}, size: {size} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + # Handle variable messages + if isinstance(response.message, ToolInvokeMessage.VariableMessage): + var_name = response.message.variable_name + var_value = response.message.variable_value + if isinstance(var_value, str): + response_content += var_value + else: + response_content += f"[Variable 
{var_name}: {json.dumps(var_value, ensure_ascii=False)}]" + + elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: + # Handle blob chunk messages - these are parts of a larger blob + if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage): + response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]" + + elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + # Handle retriever resources messages + if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage): + response_content += response.message.context + + elif response.type == ToolInvokeMessage.MessageType.FILE: + # Extract file from meta + if response.meta and "file" in response.meta: + file = response.meta["file"] + if isinstance(file, File): + # Check if file is for model or tool output + if response.meta.get("target") == "self": + # File is for model - add to files for next prompt + self.files.append(file) + response_content += f"File '{file.filename}' has been loaded into your context." + else: + # File is tool output + tool_files.append(file) + + return response_content, tool_files, None + + def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None: + """Validate tool arguments against the tool's required parameters. + + Checks that all required LLM-facing parameters are present and non-empty + before actual execution, preventing wasted tool invocations when the model + generates calls with missing arguments (e.g. empty ``{}``). + + Returns: + Error message if validation fails, None if all required parameters are satisfied. + """ + prompt_tool = tool_instance.to_prompt_message_tool() + required_params: list[str] = prompt_tool.parameters.get("required", []) + + if not required_params: + return None + + missing = [ + p + for p in required_params + if p not in tool_args + or tool_args[p] is None + or (isinstance(tool_args[p], str) and not tool_args[p].strip()) + ] + + if not missing: + return None + + return ( + f"Missing required parameter(s): {', '.join(missing)}. " + f"Please provide all required parameters before calling this tool." + ) + + def _find_tool_by_name(self, tool_name: str) -> Tool | None: + """Find a tool instance by its name.""" + for tool in self.tools: + if tool.entity.identity.name == tool_name: + return tool + return None + + def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]: + """Convert tools to prompt message format.""" + prompt_tools: list[PromptMessageTool] = [] + for tool in self.tools: + prompt_tools.append(tool.to_prompt_message_tool()) + return prompt_tools + + def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None: + """Initialize usage tracking with empty usage if not set.""" + if "usage" not in llm_usage or llm_usage["usage"] is None: + llm_usage["usage"] = LLMUsage.empty_usage() diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py new file mode 100644 index 0000000000..f4d91de8d2 --- /dev/null +++ b/api/core/agent/patterns/function_call.py @@ -0,0 +1,359 @@ +"""Function Call strategy implementation. + +Implements the Function Call agent pattern where the LLM uses native tool-calling +capability to invoke tools. Includes pre-execution parameter validation that +intercepts invalid calls (e.g. empty arguments) before they reach tool backends, +and avoids counting purely-invalid rounds against the iteration budget. 
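+
+A minimal driving sketch (setup elided; ``strategy`` is a configured
+FunctionCallStrategy). The final AgentResult arrives as the generator's
+return value, i.e. on StopIteration.value:
+
+    gen = strategy.run(prompt_messages, model_parameters={}, stream=True)
+    try:
+        while True:
+            item = next(gen)  # LLMResultChunk or AgentLog
+    except StopIteration as stop:
+        result = stop.value  # AgentResult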
+""" + +import json +import logging +from collections.abc import Generator +from typing import Any, Union + +from core.agent.entities import AgentLog, AgentResult +from core.tools.entities.tool_entities import ToolInvokeMeta +from graphon.file import File +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, +) + +from .base import AgentPattern + +logger = logging.getLogger(__name__) + + +class FunctionCallStrategy(AgentPattern): + """Function Call strategy using model's native tool calling capability.""" + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the function call agent strategy.""" + # Convert tools to prompt format + prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format() + + # Initialize tracking + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + function_call_state: bool = True + total_usage: dict[str, LLMUsage | None] = {"usage": None} + messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy + final_text: str = "" + finish_reason: str | None = None + output_files: list[File] = [] # Track files produced by tools + # Consecutive rounds where ALL tool calls failed parameter validation. + # When this happens the round is "free" (iteration_step not incremented) + # up to a safety cap to prevent infinite loops. + consecutive_validation_failures: int = 0 + max_validation_retries: int = 3 + + while function_call_state and iteration_step <= max_iterations: + function_call_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + # On last iteration, remove tools to force final answer + current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools + model_log = self._create_log( + label=f"{self.model_instance.model_name} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, LLMUsage | None] = {"usage": None} + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages, + model_parameters=model_parameters, + tools=current_tools, + stop=stop, + stream=stream, + user=self.context.user_id, + callbacks=[], + ) + + # Process response + tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log + ) + messages.append(self._create_assistant_message(response_content, tool_calls)) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update final text if no tool calls (this is likely the final answer) + if not tool_calls: + final_text = response_content + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Process tool calls + tool_outputs: dict[str, str] = {} + all_validation_errors: bool = True + if tool_calls: + function_call_state = 
True + # Execute tools (with pre-execution parameter validation) + for tool_call_id, tool_name, tool_args in tool_calls: + tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call( + tool_name, tool_args, tool_call_id, messages, round_log + ) + tool_outputs[tool_name] = tool_response + output_files.extend(tool_files) + if not is_validation_error: + all_validation_errors = False + else: + all_validation_errors = False + + yield self._finish_log( + round_log, + data={ + "llm_result": response_content, + "tool_calls": [ + {"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls + ] + if tool_calls + else [], + "final_answer": final_text if not function_call_state else None, + }, + usage=round_usage.get("usage"), + ) + + # Skip iteration counter when every tool call in this round failed validation, + # giving the model a free retry — but cap retries to prevent infinite loops. + if tool_calls and all_validation_errors: + consecutive_validation_failures += 1 + if consecutive_validation_failures >= max_validation_retries: + logger.warning( + "Agent hit %d consecutive validation-only rounds, forcing iteration increment", + consecutive_validation_failures, + ) + iteration_step += 1 + consecutive_validation_failures = 0 + else: + logger.info( + "All tool calls failed validation (attempt %d/%d), not counting iteration", + consecutive_validation_failures, + max_validation_retries, + ) + else: + consecutive_validation_failures = 0 + iteration_step += 1 + + # Return final result + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, + files=output_files, + usage=total_usage.get("usage") or LLMUsage.empty_usage(), + finish_reason=finish_reason, + ) + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, LLMUsage | None], + start_log: AgentLog, + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[list[tuple[str, str, dict[str, Any]]], str, str | None], + ]: + """Handle LLM response chunks and extract tool calls and content. + + Returns a tuple of (tool_calls, response_content, finish_reason). 
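+
+        Driven with ``yield from`` inside ``run()``, which re-yields chunks and
+        logs while capturing the returned tuple::
+
+            tool_calls, content, reason = yield from self._handle_chunks(
+                chunks, round_usage, model_log
+            )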
+ """ + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] + response_content: str = "" + finish_reason: str | None = None + if not isinstance(chunks, LLMResult): + # Streaming response + for chunk in chunks: + # Extract tool calls + if self._has_tool_calls(chunk): + tool_calls.extend(self._extract_tool_calls(chunk)) + + # Extract content + if chunk.delta.message and chunk.delta.message.content: + response_content += self._extract_content(chunk.delta.message.content) + + # Track usage + if chunk.delta.usage: + self._accumulate_usage(llm_usage, chunk.delta.usage) + + # Capture finish reason + if chunk.delta.finish_reason: + finish_reason = chunk.delta.finish_reason + + yield chunk + else: + # Non-streaming response + result: LLMResult = chunks + + if self._has_tool_calls_result(result): + tool_calls.extend(self._extract_tool_calls_result(result)) + + if result.message and result.message.content: + response_content += self._extract_content(result.message.content) + + if result.usage: + self._accumulate_usage(llm_usage, result.usage) + + # Convert to streaming format + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage), + ) + yield self._finish_log( + start_log, + data={ + "result": response_content, + }, + usage=llm_usage.get("usage"), + ) + return tool_calls, response_content, finish_reason + + def _create_assistant_message( + self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None + ) -> AssistantPromptMessage: + """Create assistant message with tool calls.""" + if tool_calls is None: + return AssistantPromptMessage(content=content) + return AssistantPromptMessage( + content=content or "", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=tc[0], + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])), + ) + for tc in tool_calls + ], + ) + + def _handle_tool_call( + self, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str, + messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]: + """Handle a single tool call and return response with files, meta, and validation status. + + Validates required parameters before execution. When validation fails the tool + is never invoked — a synthetic error is fed back to the model so it can self-correct + without consuming a real iteration. + + Returns: + (response_content, tool_files, tool_invoke_meta, is_validation_error). + ``is_validation_error`` is True when the call was rejected due to missing + required parameters, allowing the caller to skip the iteration counter. + """ + # Find tool + tool_instance = self._find_tool_by_name(tool_name) + if not tool_instance: + raise ValueError(f"Tool {tool_name} not found") + + # Get tool metadata (provider, icon, etc.) 
+ tool_metadata = self._get_tool_metadata(tool_instance) + + # Create tool call log + tool_call_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + extra_metadata=tool_metadata, + ) + yield tool_call_log + + # Validate required parameters before execution to avoid wasted invocations + validation_error = self._validate_tool_args(tool_instance, tool_args) + if validation_error: + tool_call_log.status = AgentLog.LogStatus.ERROR + tool_call_log.error = validation_error + tool_call_log.data = {**tool_call_log.data, "error": validation_error} + yield tool_call_log + + messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name)) + return validation_error, [], None, True + + # Invoke tool using base class method with error handling + try: + response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name) + + yield self._finish_log( + tool_call_log, + data={ + **tool_call_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + final_content = response_content or "Tool executed successfully" + # Add tool response to messages + messages.append( + ToolPromptMessage( + content=final_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return response_content, tool_files, tool_invoke_meta, False + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_call_log.status = AgentLog.LogStatus.ERROR + tool_call_log.error = error_message + tool_call_log.data = { + **tool_call_log.data, + "error": error_message, + } + yield tool_call_log + + # Add error message to conversation + error_content = f"Tool execution failed: {error_message}" + messages.append( + ToolPromptMessage( + content=error_content, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + return error_content, [], None, False diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py new file mode 100644 index 0000000000..179ccdb734 --- /dev/null +++ b/api/core/agent/patterns/react.py @@ -0,0 +1,419 @@ +"""ReAct strategy implementation.""" + +from __future__ import annotations + +import json +from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Union + +from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext +from core.agent.output_parser.cot_output_parser import CotAgentOutputParser +from core.model_manager import ModelInstance +from graphon.file import File +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + PromptMessage, + SystemPromptMessage, +) + +from .base import AgentPattern, ToolInvokeHook + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class ReActStrategy(AgentPattern): + """ReAct strategy using reasoning and acting approach.""" + + def __init__( + self, + model_instance: ModelInstance, + tools: list[Tool], + context: ExecutionContext, + max_iterations: int = 10, + workflow_call_depth: int = 0, + files: list[File] = [], + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ): + """Initialize the ReAct strategy with instruction support.""" + super().__init__( + model_instance=model_instance, + tools=tools, + 
context=context, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + files=files, + tool_invoke_hook=tool_invoke_hook, + ) + self.instruction = instruction + + def run( + self, + prompt_messages: list[PromptMessage], + model_parameters: dict[str, Any], + stop: list[str] = [], + stream: bool = True, + ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]: + """Execute the ReAct agent strategy.""" + # Initialize tracking + agent_scratchpad: list[AgentScratchpadUnit] = [] + iteration_step: int = 1 + max_iterations: int = self.max_iterations + 1 + react_state: bool = True + total_usage: dict[str, Any] = {"usage": None} + output_files: list[File] = [] # Track files produced by tools + final_text: str = "" + finish_reason: str | None = None + + # Add "Observation" to stop sequences + if "Observation" not in stop: + stop = stop.copy() + stop.append("Observation") + + while react_state and iteration_step <= max_iterations: + react_state = False + round_log = self._create_log( + label=f"ROUND {iteration_step}", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={}, + ) + yield round_log + + # Build prompt with/without tools based on iteration + include_tools = iteration_step < max_iterations + current_messages = self._build_prompt_with_react_format( + prompt_messages, agent_scratchpad, include_tools, self.instruction + ) + + model_log = self._create_log( + label=f"{self.model_instance.model_name} Thought", + log_type=AgentLog.LogType.THOUGHT, + status=AgentLog.LogStatus.START, + data={}, + parent_id=round_log.id, + extra_metadata={ + AgentLog.LogMetadata.PROVIDER: self.model_instance.provider, + }, + ) + yield model_log + + # Track usage for this round only + round_usage: dict[str, Any] = {"usage": None} + + # Use current messages directly (files are handled by base class if needed) + messages_to_use = current_messages + + # Invoke model + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm( + prompt_messages=messages_to_use, + model_parameters=model_parameters, + stop=stop, + stream=stream, + user=self.context.user_id or "", + callbacks=[], + ) + + # Process response + scratchpad, chunk_finish_reason = yield from self._handle_chunks( + chunks, round_usage, model_log, current_messages + ) + agent_scratchpad.append(scratchpad) + + # Accumulate to total usage + round_usage_value = round_usage.get("usage") + if round_usage_value: + self._accumulate_usage(total_usage, round_usage_value) + + # Update finish reason + if chunk_finish_reason: + finish_reason = chunk_finish_reason + + # Check if we have an action to execute + if scratchpad.action and scratchpad.action.action_name.lower() != "final answer": + react_state = True + # Execute tool + observation, tool_files = yield from self._handle_tool_call( + scratchpad.action, current_messages, round_log + ) + scratchpad.observation = observation + # Track files produced by tools + output_files.extend(tool_files) + + # Add observation to scratchpad for display + yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages) + else: + # Extract final answer + if scratchpad.action and scratchpad.action.action_input: + final_answer = scratchpad.action.action_input + if isinstance(final_answer, dict): + final_answer = json.dumps(final_answer, ensure_ascii=False) + final_text = str(final_answer) + elif scratchpad.thought: + # If no action but we have thought, use thought as final answer + final_text = scratchpad.thought + + yield 
self._finish_log( + round_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + "observation": scratchpad.observation or None, + "final_answer": final_text if not react_state else None, + }, + usage=round_usage.get("usage"), + ) + iteration_step += 1 + + # Return final result + + from core.agent.entities import AgentResult + + return AgentResult( + text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason + ) + + def _build_prompt_with_react_format( + self, + original_messages: list[PromptMessage], + agent_scratchpad: list[AgentScratchpadUnit], + include_tools: bool = True, + instruction: str = "", + ) -> list[PromptMessage]: + """Build prompt messages with ReAct format.""" + # Copy messages to avoid modifying original + messages = list(original_messages) + + # Find and update the system prompt that should already exist + system_prompt_found = False + for i, msg in enumerate(messages): + if isinstance(msg, SystemPromptMessage): + system_prompt_found = True + # The system prompt from frontend already has the template, just replace placeholders + + # Format tools + tools_str = "" + tool_names = [] + if include_tools and self.tools: + # Convert tools to prompt message tools format + prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools] + tool_names = [tool.name for tool in prompt_tools] + + # Format tools as JSON for comprehensive information + from graphon.model_runtime.utils.encoders import jsonable_encoder + + tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2) + tool_names_str = ", ".join(f'"{name}"' for name in tool_names) + else: + tools_str = "No tools available" + tool_names_str = "" + + # Replace placeholders in the existing system prompt + updated_content = msg.content + assert isinstance(updated_content, str) + updated_content = updated_content.replace("{{instruction}}", instruction) + updated_content = updated_content.replace("{{tools}}", tools_str) + updated_content = updated_content.replace("{{tool_names}}", tool_names_str) + + # Create new SystemPromptMessage with updated content + messages[i] = SystemPromptMessage(content=updated_content) + break + + # If no system prompt found, that's unexpected but add scratchpad anyway + if not system_prompt_found: + # This shouldn't happen if frontend is working correctly + pass + + # Format agent scratchpad + scratchpad_str = "" + if agent_scratchpad: + scratchpad_parts: list[str] = [] + for unit in agent_scratchpad: + if unit.thought: + scratchpad_parts.append(f"Thought: {unit.thought}") + if unit.action_str: + scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```") + if unit.observation: + scratchpad_parts.append(f"Observation: {unit.observation}") + scratchpad_str = "\n".join(scratchpad_parts) + + # If there's a scratchpad, append it to the last message + if scratchpad_str: + messages.append(AssistantPromptMessage(content=scratchpad_str)) + + return messages + + def _handle_chunks( + self, + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult], + llm_usage: dict[str, Any], + model_log: AgentLog, + current_messages: list[PromptMessage], + ) -> Generator[ + LLMResultChunk | AgentLog, + None, + tuple[AgentScratchpadUnit, str | None], + ]: + """Handle LLM response chunks and extract action/thought. + + Returns a tuple of (scratchpad_unit, finish_reason). 
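+
+        Text chunks accumulate into ``thought``; a parsed action ends the round.
+        An illustrative raw model output that parses into an action (tool name
+        and arguments are examples only)::
+
+            Thought: I should search for this first.
+            Action:
+            ```
+            {"action": "web_search", "action_input": {"query": "..."}}
+            ```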
+ """ + usage_dict: dict[str, Any] = {} + + # Convert non-streaming to streaming format if needed + if isinstance(chunks, LLMResult): + result = chunks + + def result_to_chunks() -> Generator[LLMResultChunk, None, None]: + yield LLMResultChunk( + model=result.model, + prompt_messages=result.prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=result.message, + usage=result.usage, + finish_reason=None, + ), + system_fingerprint=result.system_fingerprint or "", + ) + + streaming_chunks = result_to_chunks() + else: + streaming_chunks = chunks + + react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict) + + # Initialize scratchpad unit + scratchpad = AgentScratchpadUnit( + agent_response="", + thought="", + action_str="", + observation="", + action=None, + ) + + finish_reason: str | None = None + + # Process chunks + for chunk in react_chunks: + if isinstance(chunk, AgentScratchpadUnit.Action): + # Action detected + action_str = json.dumps(chunk.model_dump()) + scratchpad.agent_response = (scratchpad.agent_response or "") + action_str + scratchpad.action_str = action_str + scratchpad.action = chunk + + yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages) + else: + # Text chunk + chunk_text = str(chunk) + scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text + scratchpad.thought = (scratchpad.thought or "") + chunk_text + + yield self._create_text_chunk(chunk_text, current_messages) + + # Update usage + if usage_dict.get("usage"): + if llm_usage.get("usage"): + self._accumulate_usage(llm_usage, usage_dict["usage"]) + else: + llm_usage["usage"] = usage_dict["usage"] + + # Clean up thought + scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you" + + # Finish model log + yield self._finish_log( + model_log, + data={ + "thought": scratchpad.thought, + "action": scratchpad.action_str if scratchpad.action else None, + }, + usage=llm_usage.get("usage"), + ) + + return scratchpad, finish_reason + + def _handle_tool_call( + self, + action: AgentScratchpadUnit.Action, + prompt_messages: list[PromptMessage], + round_log: AgentLog, + ) -> Generator[AgentLog, None, tuple[str, list[File]]]: + """Handle tool call and return observation with files.""" + tool_name = action.action_name + tool_args: dict[str, Any] | str = action.action_input + + # Find tool instance first to get metadata + tool_instance = self._find_tool_by_name(tool_name) + tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {} + + # Start tool log with tool metadata + tool_log = self._create_log( + label=f"CALL {tool_name}", + log_type=AgentLog.LogType.TOOL_CALL, + status=AgentLog.LogStatus.START, + data={ + "tool_name": tool_name, + "tool_args": tool_args, + }, + parent_id=round_log.id, + extra_metadata=tool_metadata, + ) + yield tool_log + + if not tool_instance: + # Finish tool log with error + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "error": f"Tool {tool_name} not found", + }, + ) + return f"Tool {tool_name} not found", [] + + # Ensure tool_args is a dict + tool_args_dict: dict[str, Any] + if isinstance(tool_args, str): + try: + tool_args_dict = json.loads(tool_args) + except json.JSONDecodeError: + tool_args_dict = {"input": tool_args} + elif not isinstance(tool_args, dict): + tool_args_dict = {"input": str(tool_args)} + else: + tool_args_dict = tool_args + + # Invoke tool using base class method with error handling + try: + response_content, 
tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name) + + # Finish tool log + yield self._finish_log( + tool_log, + data={ + **tool_log.data, + "output": response_content, + "files": len(tool_files), + "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None, + }, + ) + + return response_content or "Tool executed successfully", tool_files + except Exception as e: + # Tool invocation failed, yield error log + error_message = str(e) + tool_log.status = AgentLog.LogStatus.ERROR + tool_log.error = error_message + tool_log.data = { + **tool_log.data, + "error": error_message, + } + yield tool_log + + return f"Tool execution failed: {error_message}", [] diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py new file mode 100644 index 0000000000..8d718a3d27 --- /dev/null +++ b/api/core/agent/patterns/strategy_factory.py @@ -0,0 +1,108 @@ +"""Strategy factory for creating agent strategies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.agent.entities import AgentEntity, ExecutionContext +from core.model_manager import ModelInstance +from graphon.file.models import File +from graphon.model_runtime.entities.model_entities import ModelFeature + +from .base import AgentPattern, ToolInvokeHook +from .function_call import FunctionCallStrategy +from .react import ReActStrategy + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + + +class StrategyFactory: + """Factory for creating agent strategies based on model features.""" + + # Tool calling related features + TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL} + + @staticmethod + def create_strategy( + model_features: list[ModelFeature], + model_instance: ModelInstance, + context: ExecutionContext, + tools: list[Tool], + files: list[File], + max_iterations: int = 10, + workflow_call_depth: int = 0, + agent_strategy: AgentEntity.Strategy | None = None, + tool_invoke_hook: ToolInvokeHook | None = None, + instruction: str = "", + ) -> AgentPattern: + """ + Create an appropriate strategy based on model features. 
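+
+        Illustrative auto-selection (argument values are placeholders)::
+
+            strategy = StrategyFactory.create_strategy(
+                model_features=[ModelFeature.TOOL_CALL],
+                model_instance=model_instance,
+                context=ExecutionContext.create_minimal(),
+                tools=tools,
+                files=[],
+            )
+            # -> FunctionCallStrategy, since TOOL_CALL intersects TOOL_CALL_FEATURES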
+ + Args: + model_features: List of model features/capabilities + model_instance: Model instance to use + context: Execution context containing trace/audit information + tools: Available tools + files: Available files + max_iterations: Maximum iterations for the strategy + workflow_call_depth: Depth of workflow calls + agent_strategy: Optional explicit strategy override + tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke) + instruction: Optional instruction for ReAct strategy + + Returns: + AgentStrategy instance + """ + + # If explicit strategy is provided and it's Function Calling, try to use it if supported + if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING: + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + # Fallback to ReAct if FC is requested but not supported + + # If explicit strategy is Chain of Thought (ReAct) + if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Default auto-selection logic + if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES: + # Model supports native function calling + return FunctionCallStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + ) + else: + # Use ReAct strategy for models without function calling + return ReActStrategy( + model_instance=model_instance, + context=context, + tools=tools, + files=files, + max_iterations=max_iterations, + workflow_call_depth=workflow_call_depth, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 7bb2cdb876..eb7f264c59 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: # pragma: no cover from models.model import File +from graphon.model_runtime.entities import PromptMessageTool + from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ToolEntity, @@ -154,6 +156,61 @@ class Tool(ABC): return parameters + def to_prompt_message_tool(self) -> PromptMessageTool: + """Convert this tool to a PromptMessageTool for LLM consumption.""" + message_tool = PromptMessageTool( + name=self.entity.identity.name, + description=self.entity.description.llm if self.entity.description else "", + parameters={ + "type": "object", + "properties": {}, + "required": [], + }, + ) + + parameters = self.get_merged_runtime_parameters() + for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + if parameter.type == ToolParameter.ToolParameterType.FILE: + file_format_desc = " Input the file id with format: [File: file_id]." 
+ else: + file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. " + + message_tool.parameters["properties"][parameter.name] = { + "type": "string", + "description": (parameter.llm_description or "") + file_format_desc, + } + continue + + enum = [] + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] if parameter.options else [] + + message_tool.parameters["properties"][parameter.name] = ( + { + "type": parameter_type, + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else parameter.input_schema + ) + + if len(enum) > 0: + message_tool.parameters["properties"][parameter.name]["enum"] = enum + + if parameter.required: + message_tool.parameters["required"].append(parameter.name) + + return message_tool + def create_image_message( self, image: str, diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index f6c3aee4c1..f088d7ae00 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -52,6 +52,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE +from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter +from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db @@ -394,6 +397,13 @@ class DifyNodeFactory(NodeFactory): "runtime_support": self._agent_runtime_support, "message_transformer": self._agent_message_transformer, }, + AGENT_V2_NODE_TYPE: lambda: { + "tool_manager": AgentV2ToolManager( + tenant_id=self._dify_context.tenant_id, + app_id=self._dify_context.app_id, + ), + "event_adapter": AgentV2EventAdapter(), + }, } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( diff --git a/api/core/workflow/nodes/agent_v2/__init__.py b/api/core/workflow/nodes/agent_v2/__init__.py new file mode 100644 index 0000000000..a4024a99bd --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/__init__.py @@ -0,0 +1,4 @@ +from .entities import AgentV2NodeData +from .node import AgentV2Node + +__all__ = ["AgentV2Node", "AgentV2NodeData"] diff --git a/api/core/workflow/nodes/agent_v2/entities.py b/api/core/workflow/nodes/agent_v2/entities.py new file mode 100644 index 0000000000..1535d23bdb --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/entities.py @@ -0,0 +1,86 @@ +"""Agent V2 Node data model. + +Merges LLM Node capabilities (prompt, memory, vision, context, structured output) +with Agent capabilities (tool calling loop, strategy selection). +When no tools are configured, behaves identically to an LLM Node. 
+""" + +from collections.abc import Mapping, Sequence +from typing import Any, Literal + +from graphon.entities.base_node_data import BaseNodeData +from graphon.model_runtime.entities import ImagePromptMessageContent +from graphon.nodes.llm.entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + ModelConfig, + PromptConfig, +) +from pydantic import BaseModel, Field, field_validator + +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.tools.entities.tool_entities import ToolProviderType + +AGENT_V2_NODE_TYPE = "agent-v2" + + +class ContextConfig(BaseModel): + enabled: bool + variable_selector: list[str] | None = None + + +class VisionConfigOptions(BaseModel): + variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) + detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH + + +class VisionConfig(BaseModel): + enabled: bool = False + configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) + + @field_validator("configs", mode="before") + @classmethod + def convert_none_configs(cls, v: Any): + if v is None: + return VisionConfigOptions() + return v + + +class ToolMetadata(BaseModel): + """Tool configuration for Agent V2 node.""" + + enabled: bool = True + type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow") + provider_name: str = Field(..., description="Tool provider name/identifier") + tool_name: str = Field(..., description="Tool name") + plugin_unique_identifier: str | None = Field(None) + credential_id: str | None = Field(None) + parameters: dict[str, Any] = Field(default_factory=dict) + settings: dict[str, Any] = Field(default_factory=dict) + extra: dict[str, Any] = Field(default_factory=dict) + + +class AgentV2NodeData(BaseNodeData): + """Agent V2 Node — LLM + Agent capabilities in a single workflow node.""" + + type: str = AGENT_V2_NODE_TYPE + + # --- LLM capabilities (superset of LLMNodeData) --- + model: ModelConfig + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate + prompt_config: PromptConfig = Field(default_factory=PromptConfig) + memory: MemoryConfig | None = None + context: ContextConfig = Field(default_factory=lambda: ContextConfig(enabled=False)) + vision: VisionConfig = Field(default_factory=VisionConfig) + structured_output: Mapping[str, Any] | None = None + structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") + reasoning_format: Literal["separated", "tagged"] = "tagged" + + # --- Agent capabilities --- + tools: Sequence[ToolMetadata] = Field(default_factory=list) + max_iterations: int = Field(default=10, ge=1, le=99) + agent_strategy: Literal["auto", "function-calling", "chain-of-thought"] = "auto" + + @property + def tool_call_enabled(self) -> bool: + return bool(self.tools) and any(t.enabled for t in self.tools) diff --git a/api/core/workflow/nodes/agent_v2/event_adapter.py b/api/core/workflow/nodes/agent_v2/event_adapter.py new file mode 100644 index 0000000000..15d5969c43 --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/event_adapter.py @@ -0,0 +1,96 @@ +"""Event adapter for Agent V2 Node. + +Converts AgentPattern outputs (LLMResultChunk | AgentLog) into +graphon NodeEventBase events consumable by the workflow engine. 
+""" + +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +from graphon.model_runtime.entities import LLMResultChunk +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import ( + AgentLogEvent, + ModelInvokeCompletedEvent, + NodeEventBase, + StreamChunkEvent, +) + +from core.agent.entities import AgentLog, AgentResult + + +class AgentV2EventAdapter: + """Converts agent strategy outputs into workflow node events.""" + + def process_strategy_outputs( + self, + outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult], + *, + node_id: str, + node_execution_id: str, + ) -> Generator[NodeEventBase, None, AgentResult]: + """Process strategy generator outputs, yielding node events. + + Returns the final AgentResult from the strategy. + """ + try: + while True: + item = next(outputs) + if isinstance(item, AgentLog): + yield self._convert_agent_log(item, node_id=node_id, node_execution_id=node_execution_id) + elif isinstance(item, LLMResultChunk): + yield from self._convert_llm_chunk(item, node_id=node_id) + except StopIteration as e: + result: AgentResult = e.value + if result.usage: + usage = result.usage if isinstance(result.usage, LLMUsage) else LLMUsage.empty_usage() + yield ModelInvokeCompletedEvent( + text=result.text, + usage=usage, + finish_reason=result.finish_reason, + ) + return result + + def _convert_agent_log( + self, + log: AgentLog, + *, + node_id: str, + node_execution_id: str, + ) -> AgentLogEvent: + return AgentLogEvent( + message_id=log.id, + label=log.label, + node_execution_id=node_execution_id, + parent_id=log.parent_id, + error=log.error, + status=log.status.value, + data=dict(log.data), + metadata={k.value if hasattr(k, "value") else str(k): v for k, v in log.metadata.items()}, + node_id=node_id, + ) + + def _convert_llm_chunk( + self, + chunk: LLMResultChunk, + *, + node_id: str, + ) -> Generator[NodeEventBase, None, None]: + content = "" + if chunk.delta.message and chunk.delta.message.content: + if isinstance(chunk.delta.message.content, str): + content = chunk.delta.message.content + elif isinstance(chunk.delta.message.content, list): + from graphon.model_runtime.entities.message_entities import TextPromptMessageContent + + for item in chunk.delta.message.content: + if isinstance(item, TextPromptMessageContent): + content += item.data + + if content: + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=content, + ) diff --git a/api/core/workflow/nodes/agent_v2/node.py b/api/core/workflow/nodes/agent_v2/node.py new file mode 100644 index 0000000000..02714129ad --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/node.py @@ -0,0 +1,370 @@ +"""Agent V2 Workflow Node. + +A unified workflow node that combines LLM capabilities with agent tool-calling. +When tools are configured, runs an FC/ReAct loop via StrategyFactory. +When no tools are present, behaves as a single-shot LLM invocation. 
+""" + +from __future__ import annotations + +import logging +import re +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, cast + +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessageContentUnionTypes, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.node_events import ( + ModelInvokeCompletedEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +from core.agent.entities import AgentEntity, ExecutionContext +from core.agent.patterns import StrategyFactory +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.model_manager import ModelInstance, ModelManager +from core.workflow.system_variables import SystemVariableKey, get_system_text + +from .entities import AGENT_V2_NODE_TYPE, AgentV2NodeData +from .event_adapter import AgentV2EventAdapter +from .tool_manager import AgentV2ToolManager + +if TYPE_CHECKING: + from graphon.entities import GraphInitParams + from graphon.entities.graph_config import NodeConfigDict + from graphon.runtime import GraphRuntimeState + +logger = logging.getLogger(__name__) + +_THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) + + +class AgentV2Node(Node[AgentV2NodeData]): + node_type = AGENT_V2_NODE_TYPE + + _tool_manager: AgentV2ToolManager + _event_adapter: AgentV2EventAdapter + + def __init__( + self, + id: str, + config: NodeConfigDict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + *, + tool_manager: AgentV2ToolManager, + event_adapter: AgentV2EventAdapter, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._tool_manager = tool_manager + self._event_adapter = event_adapter + + @classmethod + def version(cls) -> str: + return "1" + + def _run(self) -> Generator[NodeEventBase, None, None]: + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) + + try: + model_instance = self._fetch_model_instance(dify_ctx) + except Exception as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error=f"Failed to load model: {e}", + ) + ) + return + + prompt_messages = self._build_prompt_messages(dify_ctx) + + if self.node_data.tool_call_enabled: + yield from self._run_with_tools(model_instance, prompt_messages, dify_ctx) + else: + yield from self._run_without_tools(model_instance, prompt_messages, dify_ctx) + + # ------------------------------------------------------------------ + # No-tools path: single LLM invocation (LLM Node equivalent) + # ------------------------------------------------------------------ + + def _run_without_tools( + self, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + dify_ctx: DifyRunContext, + ) -> Generator[NodeEventBase, None, None]: + try: + result_chunks: Generator[LLMResultChunk, None, 
None] = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=self.node_data.model.completion_params, + tools=[], + stop=[], + stream=True, + user=dify_ctx.user_id, + callbacks=[], + ) + + full_text = "" + reasoning_content = "" + usage: LLMUsage | None = None + finish_reason: str | None = None + + for chunk in result_chunks: + chunk_text = self._extract_chunk_text(chunk) + if chunk_text: + full_text += chunk_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=chunk_text, + ) + + if chunk.delta.usage: + usage = chunk.delta.usage + if chunk.delta.finish_reason: + finish_reason = chunk.delta.finish_reason + + if self.node_data.reasoning_format == "separated": + full_text, reasoning_content = self._separate_reasoning(full_text) + + if usage: + yield ModelInvokeCompletedEvent( + text=full_text, + usage=usage, + finish_reason=finish_reason, + reasoning_content=reasoning_content or None, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]}, + outputs={ + "text": full_text, + "reasoning_content": reasoning_content, + "usage": usage.model_dump() if usage else {}, + "finish_reason": finish_reason or "stop", + }, + ) + ) + except Exception as e: + logger.exception("Agent V2 LLM invocation failed") + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error=str(e), + ) + ) + + # ------------------------------------------------------------------ + # Tools path: agent loop via StrategyFactory + # ------------------------------------------------------------------ + + def _run_with_tools( + self, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + dify_ctx: DifyRunContext, + ) -> Generator[NodeEventBase, None, None]: + try: + tool_instances = self._tool_manager.prepare_tool_instances( + list(self.node_data.tools), + ) + + model_features = self._get_model_features(model_instance) + + context = ExecutionContext( + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + tenant_id=dify_ctx.tenant_id, + conversation_id=get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.CONVERSATION_ID, + ), + ) + + agent_strategy_enum = self._map_strategy_config(self.node_data.agent_strategy) + + strategy = StrategyFactory.create_strategy( + model_features=model_features, + model_instance=model_instance, + tools=tool_instances, + files=[], + max_iterations=self.node_data.max_iterations, + context=context, + agent_strategy=agent_strategy_enum, + tool_invoke_hook=self._tool_manager.create_workflow_tool_invoke_hook(context), + ) + + outputs_gen = strategy.run( + prompt_messages=prompt_messages, + model_parameters=self.node_data.model.completion_params, + stop=[], + stream=True, + ) + + result = yield from self._event_adapter.process_strategy_outputs( + outputs_gen, + node_id=self._node_id, + node_execution_id=self.id, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]}, + outputs={ + "text": result.text, + "files": [f.model_dump() if hasattr(f, "model_dump") else str(f) for f in result.files], + "usage": result.usage.model_dump() if hasattr(result.usage, "model_dump") else {}, + "finish_reason": result.finish_reason or "stop", + }, + ) + ) + except Exception as e: + logger.exception("Agent V2 
tool execution failed")
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs={},
+                    error=str(e),
+                )
+            )
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _fetch_model_instance(self, dify_ctx: DifyRunContext) -> ModelInstance:
+        model_config = self.node_data.model
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=dify_ctx.tenant_id,
+            provider=model_config.provider,
+            model_type=ModelType.LLM,
+            model=model_config.name,
+        )
+        return model_instance
+
+    def _build_prompt_messages(self, dify_ctx: DifyRunContext) -> list[PromptMessage]:
+        """Build prompt messages from the node's prompt_template, resolving variables."""
+        variable_pool = self.graph_runtime_state.variable_pool
+        messages: list[PromptMessage] = []
+
+        template = self.node_data.prompt_template
+        if isinstance(template, Sequence) and not isinstance(template, str):
+            for msg_template in template:
+                role = msg_template.role.value if hasattr(msg_template.role, "value") else str(msg_template.role)
+                text = msg_template.text or ""
+                jinja2_text = getattr(msg_template, "jinja2_text", None)
+                content = jinja2_text or text
+
+                resolved = VariableTemplateParser.resolve_template(content, variable_pool)
+
+                if role == "system":
+                    messages.append(SystemPromptMessage(content=resolved))
+                elif role == "user":
+                    messages.append(UserPromptMessage(content=resolved))
+                elif role == "assistant":
+                    messages.append(AssistantPromptMessage(content=resolved))
+        else:
+            text_content = getattr(template, "text", "") or ""
+            resolved = VariableTemplateParser.resolve_template(text_content, variable_pool)
+            messages.append(UserPromptMessage(content=resolved))
+
+        return messages
+
+    def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
+        try:
+            model_schema = model_instance.model_type_instance.get_model_schema(
+                model_instance.model_name,
+                model_instance.credentials,
+            )
+            return list(model_schema.features) if model_schema and model_schema.features else []
+        except Exception:
+            logger.warning("Failed to get model features, assuming none")
+            return []
+
+    @staticmethod
+    def _map_strategy_config(
+        config_value: Literal["auto", "function-calling", "chain-of-thought"],
+    ) -> AgentEntity.Strategy | None:
+        mapping = {
+            "function-calling": AgentEntity.Strategy.FUNCTION_CALLING,
+            "chain-of-thought": AgentEntity.Strategy.CHAIN_OF_THOUGHT,
+        }
+        return mapping.get(config_value)
+
+    @staticmethod
+    def _extract_chunk_text(chunk: LLMResultChunk) -> str:
+        if not chunk.delta.message or not chunk.delta.message.content:
+            return ""
+        content = chunk.delta.message.content
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            parts = []
+            for item in content:
+                if isinstance(item, TextPromptMessageContent):
+                    parts.append(item.data)
+            return "".join(parts)
+        return ""
+
+    @staticmethod
+    def _separate_reasoning(text: str) -> tuple[str, str]:
+        """Extract <think> blocks from text, return (clean_text, reasoning_content)."""
+        reasoning_parts = _THINK_PATTERN.findall(text)
+        reasoning_content = "\n".join(reasoning_parts)
+        clean_text = _THINK_PATTERN.sub("", text).strip()
+        return clean_text, reasoning_content
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        *,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: AgentV2NodeData,
+    ) -> Mapping[str, Sequence[str]]:
+        result: dict[str, 
list[str]] = {} + + if isinstance(node_data.prompt_template, Sequence) and not isinstance(node_data.prompt_template, str): + for msg in node_data.prompt_template: + text = msg.text or "" + jinja2_text = getattr(msg, "jinja2_text", None) + content = jinja2_text or text + selectors = VariableTemplateParser(content).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + else: + text_content = getattr(node_data.prompt_template, "text", "") or "" + selectors = VariableTemplateParser(text_content).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + + return {f"{node_id}.{key}": value for key, value in result.items()} diff --git a/api/core/workflow/nodes/agent_v2/tool_manager.py b/api/core/workflow/nodes/agent_v2/tool_manager.py new file mode 100644 index 0000000000..8cc1409dd9 --- /dev/null +++ b/api/core/workflow/nodes/agent_v2/tool_manager.py @@ -0,0 +1,122 @@ +"""Tool management for Agent V2 Node. + +Handles tool instance preparation, conversion to LLM-consumable format, +and creation of workflow-compatible tool invoke hooks. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Generator +from typing import TYPE_CHECKING, Any + +from graphon.file import File +from graphon.model_runtime.entities import PromptMessageTool + +from core.agent.entities import AgentToolEntity, ExecutionContext +from core.agent.patterns.base import ToolInvokeHook +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeMeta, ToolInvokeMessage +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager + +if TYPE_CHECKING: + from .entities import ToolMetadata + +logger = logging.getLogger(__name__) + + +class AgentV2ToolManager: + """Manages tool lifecycle for Agent V2 node execution.""" + + def __init__( + self, + *, + tenant_id: str, + app_id: str, + ) -> None: + self._tenant_id = tenant_id + self._app_id = app_id + + def prepare_tool_instances( + self, + tools_config: list[ToolMetadata], + ) -> list[Tool]: + """Convert tool metadata configs into runtime Tool instances.""" + tool_instances: list[Tool] = [] + for tool_meta in tools_config: + if not tool_meta.enabled: + continue + try: + processed_settings = {} + for key, value in tool_meta.settings.items(): + if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict): + if "type" in value["value"] and "value" in value["value"]: + processed_settings[key] = value["value"] + else: + processed_settings[key] = value + else: + processed_settings[key] = value + + merged_parameters = {**tool_meta.parameters, **processed_settings} + + agent_tool = AgentToolEntity( + provider_id=tool_meta.provider_name, + provider_type=tool_meta.type, + tool_name=tool_meta.tool_name, + tool_parameters=merged_parameters, + plugin_unique_identifier=tool_meta.plugin_unique_identifier, + credential_id=tool_meta.credential_id, + ) + + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=self._tenant_id, + app_id=self._app_id, + agent_tool=agent_tool, + ) + tool_instances.append(tool_runtime) + except Exception: + logger.warning("Failed to prepare tool %s/%s, skipping", tool_meta.provider_name, tool_meta.tool_name, exc_info=True) + continue + + return tool_instances + + def create_workflow_tool_invoke_hook( + self, + context: 
ExecutionContext, + workflow_call_depth: int = 0, + ) -> ToolInvokeHook: + """Create a ToolInvokeHook for workflow context (uses generic_invoke).""" + + def hook( + tool: Tool, + tool_args: dict[str, Any], + tool_name: str, + ) -> tuple[str, list[str], ToolInvokeMeta]: + tool_response = ToolEngine.generic_invoke( + tool=tool, + tool_parameters=tool_args, + user_id=context.user_id or "", + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=workflow_call_depth, + app_id=context.app_id, + conversation_id=context.conversation_id, + ) + + response_content = "" + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + response_content += response.message.text + elif response.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(response.message, ToolInvokeMessage.JsonMessage): + response_content += json.dumps(response.message.json_object, ensure_ascii=False) + elif response.type == ToolInvokeMessage.MessageType.LINK: + if isinstance(response.message, ToolInvokeMessage.TextMessage): + response_content += f"[Link: {response.message.text}]" + + return response_content, [], ToolInvokeMeta.empty() + + return hook diff --git a/api/models/model.py b/api/models/model.py index 43ddf344d2..359b570f5f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -352,6 +352,7 @@ class AppMode(StrEnum): CHAT = "chat" ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" + AGENT = "agent" CHANNEL = "channel" RAG_PIPELINE = "rag-pipeline" diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index dd73e10374..698c49e7d2 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -483,7 +483,7 @@ class AppDslService: ) # Initialize app based on mode - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}: workflow_data = data.get("workflow") if not workflow_data or not isinstance(workflow_data, dict): raise ValueError("Missing workflow data for workflow/advanced chat app") @@ -566,7 +566,7 @@ class AppDslService: }, } - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id ) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 17ed98d301..9d87ce4b9a 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -147,6 +147,54 @@ class AppGenerateService: ), request_id=request_id, ) + case AppMode.AGENT: + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + + if streaming: + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=True, + call_depth=0, + ) + payload_json = payload.model_dump_json() + + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) + + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + generator = AdvancedChatAppGenerator() + return rate_limit.generate( + generator.convert_to_event_stream( + generator.retrieve_events( + AppMode.ADVANCED_CHAT, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + 
request_id=request_id, + ) + else: + advanced_generator = AdvancedChatAppGenerator() + return rate_limit.generate( + advanced_generator.convert_to_event_stream( + advanced_generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + workflow_run_id=str(uuid.uuid4()), + streaming=False, + ) + ), + request_id=request_id, + ) case AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) diff --git a/api/services/app_service.py b/api/services/app_service.py index 87d52a3159..6cabd57b67 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -10,6 +10,7 @@ from sqlalchemy import select from configs import dify_config from constants.model_template import default_app_templates +from services.workflow.graph_factory import WorkflowGraphFactory from core.agent.entities import AgentToolEntity from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager @@ -169,6 +170,9 @@ class AppService: db.session.commit() + if app_mode == AppMode.AGENT: + self._init_agent_workflow(app, account, default_model_dict if default_model_config else None) + app_was_created.send(app, account=account) if FeatureService.get_system_features().webapp_auth.enabled: @@ -180,6 +184,34 @@ class AppService: return app + @staticmethod + def _init_agent_workflow(app: App, account: Any, model_dict: dict | None) -> None: + """Create the default single-agent-node workflow for a new Agent app.""" + from services.workflow_service import WorkflowService + + model_config = model_dict or { + "provider": "openai", + "name": "gpt-4o", + "mode": "chat", + "completion_params": {}, + } + + graph = WorkflowGraphFactory.create_single_agent_graph( + model_config=model_config, + is_chat=True, + ) + + workflow_service = WorkflowService() + workflow_service.sync_draft_workflow( + app_model=app, + graph=graph, + features={}, + unique_hash=None, + account=account, + environment_variables=[], + conversation_variables=[], + ) + def get_app(self, app: App) -> App: """ Get App diff --git a/api/services/workflow/graph_factory.py b/api/services/workflow/graph_factory.py new file mode 100644 index 0000000000..abaf69ddb4 --- /dev/null +++ b/api/services/workflow/graph_factory.py @@ -0,0 +1,113 @@ +"""Factory for programmatically building workflow graphs. + +Used by AppService to auto-generate single-node workflow graphs when +creating a new Agent app (AppMode.AGENT). +""" + +from typing import Any + +from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE + + +class WorkflowGraphFactory: + """Builds workflow graph dicts for special app creation flows.""" + + @staticmethod + def create_single_agent_graph( + model_config: dict[str, Any], + is_chat: bool = True, + ) -> dict[str, Any]: + """Create a minimal start -> agent_v2 -> answer/end graph. + + Args: + model_config: Model configuration dict with provider, name, mode, completion_params. + is_chat: If True, creates chatflow (with answer node); otherwise workflow (with end node). + + Returns: + Graph dict with nodes and edges, ready for WorkflowService.sync_draft_workflow(). 
+ """ + agent_node_data: dict[str, Any] = { + "type": AGENT_V2_NODE_TYPE, + "title": "Agent", + "model": model_config, + "prompt_template": [ + {"role": "system", "text": "You are a helpful assistant."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "tools": [], + "max_iterations": 10, + "agent_strategy": "auto", + "context": {"enabled": False}, + "vision": {"enabled": False}, + } + + if is_chat: + agent_node_data["memory"] = {"window": {"enabled": True, "size": 50}} + + nodes: list[dict[str, Any]] = [ + { + "id": "start", + "type": "custom", + "data": {"type": "start", "title": "Start", "variables": []}, + "position": {"x": 80, "y": 282}, + }, + { + "id": "agent", + "type": "custom", + "data": agent_node_data, + "position": {"x": 400, "y": 282}, + }, + ] + + if is_chat: + nodes.append( + { + "id": "answer", + "type": "custom", + "data": { + "type": "answer", + "title": "Answer", + "answer": "{{#agent.text#}}", + }, + "position": {"x": 720, "y": 282}, + } + ) + end_node_id = "answer" + else: + nodes.append( + { + "id": "end", + "type": "custom", + "data": { + "type": "end", + "title": "End", + "outputs": [ + { + "value_selector": ["agent", "text"], + "variable": "result", + } + ], + }, + "position": {"x": 720, "y": 282}, + } + ) + end_node_id = "end" + + edges: list[dict[str, str]] = [ + { + "id": "start-agent", + "source": "start", + "target": "agent", + "sourceHandle": "source", + "targetHandle": "target", + }, + { + "id": f"agent-{end_node_id}", + "source": "agent", + "target": end_node_id, + "sourceHandle": "source", + "targetHandle": "target", + }, + ] + + return {"nodes": nodes, "edges": edges} diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/__init__.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_basic.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_basic.py new file mode 100644 index 0000000000..092550e278 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_basic.py @@ -0,0 +1,332 @@ +"""Basic tests for Agent V2 node — Phase 1 + 2 validation. + +Tests: +1. Module imports resolve without errors +2. AgentV2Node self-registers in the graphon Node registry +3. DifyNodeFactory kwargs mapping includes agent-v2 +4. StrategyFactory selects correct strategy based on model features +5. 
AgentV2NodeData validates with and without tools +""" + +import pytest + + +class TestPhase1Imports: + """Verify Phase 1 (Agent Patterns) modules import correctly.""" + + def test_entities_import(self): + from core.agent.entities import AgentLog, AgentResult, ExecutionContext + + assert ExecutionContext is not None + assert AgentLog is not None + assert AgentResult is not None + + def test_entities_backward_compatible(self): + from core.agent.entities import ( + AgentEntity, + AgentInvokeMessage, + AgentPromptEntity, + AgentScratchpadUnit, + AgentToolEntity, + ) + + assert AgentEntity is not None + assert AgentToolEntity is not None + assert AgentPromptEntity is not None + assert AgentScratchpadUnit is not None + assert AgentInvokeMessage is not None + + def test_patterns_module_import(self): + from core.agent.patterns import ( + AgentPattern, + FunctionCallStrategy, + ReActStrategy, + StrategyFactory, + ) + + assert AgentPattern is not None + assert FunctionCallStrategy is not None + assert ReActStrategy is not None + assert StrategyFactory is not None + + def test_patterns_inheritance(self): + from core.agent.patterns import AgentPattern, FunctionCallStrategy, ReActStrategy + + assert issubclass(FunctionCallStrategy, AgentPattern) + assert issubclass(ReActStrategy, AgentPattern) + + +class TestPhase2Imports: + """Verify Phase 2 (Agent V2 Node) modules import correctly.""" + + def test_entities_import(self): + from core.workflow.nodes.agent_v2.entities import ( + AGENT_V2_NODE_TYPE, + AgentV2NodeData, + ContextConfig, + ToolMetadata, + VisionConfig, + ) + + assert AGENT_V2_NODE_TYPE == "agent-v2" + assert AgentV2NodeData is not None + assert ToolMetadata is not None + + def test_node_import(self): + from core.workflow.nodes.agent_v2.node import AgentV2Node + + assert AgentV2Node is not None + assert AgentV2Node.node_type == "agent-v2" + + def test_tool_manager_import(self): + from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager + + assert AgentV2ToolManager is not None + + def test_event_adapter_import(self): + from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter + + assert AgentV2EventAdapter is not None + + +class TestNodeRegistration: + """Verify AgentV2Node self-registers in the graphon Node registry.""" + + def test_agent_v2_in_registry(self): + from core.workflow.node_factory import register_nodes + + register_nodes() + + from graphon.nodes.base.node import Node + + registry = Node.get_node_type_classes_mapping() + assert "agent-v2" in registry, f"agent-v2 not found in registry. 
Available: {list(registry.keys())}" + + def test_agent_v2_latest_version(self): + from core.workflow.node_factory import register_nodes + + register_nodes() + + from graphon.nodes.base.node import Node + + registry = Node.get_node_type_classes_mapping() + agent_v2_versions = registry.get("agent-v2", {}) + assert "latest" in agent_v2_versions + assert "1" in agent_v2_versions + + from core.workflow.nodes.agent_v2.node import AgentV2Node + + assert agent_v2_versions["latest"] is AgentV2Node + assert agent_v2_versions["1"] is AgentV2Node + + def test_old_agent_still_registered(self): + """Old Agent node must not be affected by Agent V2.""" + from core.workflow.node_factory import register_nodes + + register_nodes() + + from graphon.nodes.base.node import Node + + registry = Node.get_node_type_classes_mapping() + assert "agent" in registry, "Old agent node must still be registered" + + def test_resolve_workflow_node_class(self): + from core.workflow.node_factory import register_nodes, resolve_workflow_node_class + from core.workflow.nodes.agent_v2.node import AgentV2Node + + register_nodes() + + resolved = resolve_workflow_node_class(node_type="agent-v2", node_version="1") + assert resolved is AgentV2Node + + resolved_latest = resolve_workflow_node_class(node_type="agent-v2", node_version="latest") + assert resolved_latest is AgentV2Node + + +class TestNodeFactoryKwargs: + """Verify DifyNodeFactory includes agent-v2 in kwargs mapping.""" + + def test_agent_v2_node_type_in_factory(self): + from core.workflow.node_factory import AGENT_V2_NODE_TYPE + + assert AGENT_V2_NODE_TYPE == "agent-v2" + + +class TestStrategyFactory: + """Verify StrategyFactory selects correct strategy.""" + + def test_fc_selected_for_tool_call_model(self): + from graphon.model_runtime.entities.model_entities import ModelFeature + + from core.agent.patterns import FunctionCallStrategy, StrategyFactory + + assert ModelFeature.TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES + assert ModelFeature.MULTI_TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES + + def test_factory_has_create_strategy(self): + from core.agent.patterns import StrategyFactory + + assert callable(getattr(StrategyFactory, "create_strategy", None)) + + +class TestAgentV2NodeData: + """Verify AgentV2NodeData validation.""" + + def test_minimal_data(self): + from core.workflow.nodes.agent_v2.entities import AgentV2NodeData + + data = AgentV2NodeData( + title="Test Agent", + model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + prompt_template=[{"role": "system", "text": "You are helpful."}, {"role": "user", "text": "Hello"}], + context={"enabled": False}, + ) + assert data.type == "agent-v2" + assert data.tool_call_enabled is False + assert data.max_iterations == 10 + assert data.agent_strategy == "auto" + + def test_data_with_tools(self): + from core.workflow.nodes.agent_v2.entities import AgentV2NodeData + + data = AgentV2NodeData( + title="Test Agent with Tools", + model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + prompt_template=[{"role": "user", "text": "Search for {{query}}"}], + context={"enabled": False}, + tools=[ + { + "enabled": True, + "type": "builtin", + "provider_name": "google", + "tool_name": "google_search", + } + ], + max_iterations=5, + agent_strategy="function-calling", + ) + assert data.tool_call_enabled is True + assert data.max_iterations == 5 + assert data.agent_strategy == "function-calling" + assert len(data.tools) == 1 + + def 
test_data_with_disabled_tools(self): + from core.workflow.nodes.agent_v2.entities import AgentV2NodeData + + data = AgentV2NodeData( + title="Test Agent", + model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + prompt_template=[{"role": "user", "text": "Hello"}], + context={"enabled": False}, + tools=[ + { + "enabled": False, + "type": "builtin", + "provider_name": "google", + "tool_name": "google_search", + } + ], + ) + assert data.tool_call_enabled is False + + def test_data_with_memory(self): + from core.workflow.nodes.agent_v2.entities import AgentV2NodeData + + data = AgentV2NodeData( + title="Test Agent", + model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + prompt_template=[{"role": "user", "text": "Hello"}], + context={"enabled": False}, + memory={"window": {"enabled": True, "size": 50}}, + ) + assert data.memory is not None + assert data.memory.window.enabled is True + assert data.memory.window.size == 50 + + def test_data_with_vision(self): + from core.workflow.nodes.agent_v2.entities import AgentV2NodeData + + data = AgentV2NodeData( + title="Test Agent", + model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + prompt_template=[{"role": "user", "text": "Hello"}], + context={"enabled": False}, + vision={"enabled": True}, + ) + assert data.vision.enabled is True + + +class TestExecutionContext: + """Verify ExecutionContext entity.""" + + def test_create_minimal(self): + from core.agent.entities import ExecutionContext + + ctx = ExecutionContext.create_minimal(user_id="user-123") + assert ctx.user_id == "user-123" + assert ctx.app_id is None + + def test_to_dict(self): + from core.agent.entities import ExecutionContext + + ctx = ExecutionContext(user_id="u1", app_id="a1", tenant_id="t1") + d = ctx.to_dict() + assert d["user_id"] == "u1" + assert d["app_id"] == "a1" + assert d["tenant_id"] == "t1" + assert d["conversation_id"] is None + + def test_with_updates(self): + from core.agent.entities import ExecutionContext + + ctx = ExecutionContext(user_id="u1") + ctx2 = ctx.with_updates(app_id="a1", conversation_id="c1") + assert ctx2.user_id == "u1" + assert ctx2.app_id == "a1" + assert ctx2.conversation_id == "c1" + + +class TestAgentLog: + """Verify AgentLog entity.""" + + def test_create_log(self): + from core.agent.entities import AgentLog + + log = AgentLog( + label="Round 1", + log_type=AgentLog.LogType.ROUND, + status=AgentLog.LogStatus.START, + data={"key": "value"}, + ) + assert log.id is not None + assert log.label == "Round 1" + assert log.log_type == "round" + assert log.status == "start" + assert log.parent_id is None + + def test_log_types(self): + from core.agent.entities import AgentLog + + assert AgentLog.LogType.ROUND == "round" + assert AgentLog.LogType.THOUGHT == "thought" + assert AgentLog.LogType.TOOL_CALL == "tool_call" + + +class TestAgentResult: + """Verify AgentResult entity.""" + + def test_default_result(self): + from core.agent.entities import AgentResult + + result = AgentResult() + assert result.text == "" + assert result.files == [] + assert result.usage is None + assert result.finish_reason is None + + def test_result_with_data(self): + from core.agent.entities import AgentResult + + result = AgentResult(text="Hello world", finish_reason="stop") + assert result.text == "Hello world" + assert result.finish_reason == "stop" diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_phase3.py 
b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_phase3.py new file mode 100644 index 0000000000..6dcd302bed --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_agent_v2_phase3.py @@ -0,0 +1,132 @@ +"""Tests for Phase 3 — Agent app type support.""" + +import pytest + + +class TestAppModeAgent: + """Verify AppMode.AGENT is properly defined.""" + + def test_agent_mode_exists(self): + from models.model import AppMode + + assert hasattr(AppMode, "AGENT") + assert AppMode.AGENT == "agent" + + def test_agent_mode_value_of(self): + from models.model import AppMode + + mode = AppMode.value_of("agent") + assert mode == AppMode.AGENT + + def test_all_original_modes_still_work(self): + from models.model import AppMode + + for val in ["completion", "workflow", "chat", "advanced-chat", "agent-chat", "channel", "rag-pipeline"]: + mode = AppMode.value_of(val) + assert mode.value == val + + +class TestDefaultAppTemplate: + """Verify AGENT template is defined.""" + + def test_agent_template_exists(self): + from constants.model_template import default_app_templates + from models.model import AppMode + + assert AppMode.AGENT in default_app_templates + template = default_app_templates[AppMode.AGENT] + assert template["app"]["mode"] == AppMode.AGENT + assert template["app"]["enable_site"] is True + assert "model_config" in template + + def test_all_original_templates_exist(self): + from constants.model_template import default_app_templates + from models.model import AppMode + + for mode in [AppMode.WORKFLOW, AppMode.COMPLETION, AppMode.CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT]: + assert mode in default_app_templates + + +class TestWorkflowGraphFactory: + """Verify WorkflowGraphFactory creates valid graphs.""" + + def test_create_chat_graph(self): + from services.workflow.graph_factory import WorkflowGraphFactory + + model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}} + graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=True) + + assert "nodes" in graph + assert "edges" in graph + assert len(graph["nodes"]) == 3 + assert len(graph["edges"]) == 2 + + node_types = [n["data"]["type"] for n in graph["nodes"]] + assert "start" in node_types + assert "agent-v2" in node_types + assert "answer" in node_types + + agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2") + assert agent_node["data"]["model"] == model_config + assert agent_node["data"]["memory"] is not None + + def test_create_workflow_graph(self): + from services.workflow.graph_factory import WorkflowGraphFactory + + model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}} + graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=False) + + node_types = [n["data"]["type"] for n in graph["nodes"]] + assert "end" in node_types + assert "answer" not in node_types + + agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2") + assert agent_node["data"].get("memory") is None + + def test_edge_connectivity(self): + from services.workflow.graph_factory import WorkflowGraphFactory + + graph = WorkflowGraphFactory.create_single_agent_graph( + {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}, + is_chat=True, + ) + + edges = graph["edges"] + sources = {e["source"] for e in edges} + targets = {e["target"] for e in edges} + assert "start" in sources + assert "agent" in sources + assert "agent" in targets + assert 
"answer" in targets + + +class TestConsoleAppController: + """Verify Console API allows 'agent' mode.""" + + def test_allow_create_app_modes(self): + from controllers.console.app.app import ALLOW_CREATE_APP_MODES + + assert "agent" in ALLOW_CREATE_APP_MODES + assert "chat" in ALLOW_CREATE_APP_MODES + assert "agent-chat" in ALLOW_CREATE_APP_MODES + + +class TestAppGenerateServiceHasAgentCase: + """Verify the generate() method has an AppMode.AGENT case.""" + + def test_generate_method_exists(self): + from services.app_generate_service import AppGenerateService + + assert hasattr(AppGenerateService, "generate") + + def test_agent_mode_import(self): + """Verify AppMode.AGENT can be used in match statement context.""" + from models.model import AppMode + + mode = AppMode.AGENT + match mode: + case AppMode.AGENT: + result = "agent" + case _: + result = "other" + assert result == "agent"