Mirror of https://github.com/langgenius/dify.git, synced 2026-05-10 05:56:31 +08:00
feat(api): add Agent V2 node and new Agent app type (Phase 1-3)
Introduce a new unified Agent V2 workflow node that combines LLM capabilities with agent tool-calling loops, along with a new AppMode.AGENT for standalone agent apps backed by single-node workflows.

Phase 1 — Agent Patterns:
- Add core/agent/patterns/ module (AgentPattern, FunctionCallStrategy, ReActStrategy, StrategyFactory) ported from feat/support-agent-sandbox
- Add ExecutionContext, AgentLog, AgentResult entities
- Add Tool.to_prompt_message_tool() for LLM-consumable tool conversion

Phase 2 — Agent V2 Workflow Node:
- Add core/workflow/nodes/agent_v2/ (AgentV2Node, AgentV2NodeData, AgentV2ToolManager, AgentV2EventAdapter)
- Register agent-v2 node type in DifyNodeFactory
- No-tools path: single LLM call (LLM Node equivalent)
- Tools path: FC/ReAct loop via StrategyFactory

Phase 3 — Agent App Type:
- Add AppMode.AGENT to model enum
- Add WorkflowGraphFactory for auto-generating start->agent_v2->answer graphs
- AppService.create_app() creates workflow draft for AGENT mode
- AppGenerateService.generate() routes AGENT to AdvancedChatAppGenerator
- Console API and DSL import/export support AGENT mode
- Default app template for AGENT mode

Old agent/agent-chat/LLM node paths are fully preserved. All 38 unit tests pass.

Made-with: Cursor
This commit is contained in:
parent
b1adb5652e
commit
96641a93f6
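The commit message mentions a WorkflowGraphFactory that auto-generates start->agent_v2->answer graphs for the new AGENT mode. For orientation, such a graph might look like the sketch below; the exact node schema is an assumption, not something shown in this diff.

```python
# Hedged sketch of a single-agent workflow graph (start -> agent_v2 -> answer).
# Node/edge field names are illustrative, not taken from this commit.
graph = {
    "nodes": [
        {"id": "start", "data": {"type": "start", "title": "Start"}},
        {"id": "agent", "data": {"type": "agent-v2", "title": "Agent"}},
        {"id": "answer", "data": {"type": "answer", "title": "Answer"}},
    ],
    "edges": [
        {"source": "start", "target": "agent"},
        {"source": "agent", "target": "answer"},
    ],
}
```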
@@ -81,4 +81,20 @@ default_app_templates: Mapping[AppMode, Mapping] = {
             },
         },
     },
+    # agent default mode (new agent backed by single-node workflow)
+    AppMode.AGENT: {
+        "app": {
+            "mode": AppMode.AGENT,
+            "enable_site": True,
+            "enable_api": True,
+        },
+        "model_config": {
+            "model": {
+                "provider": "openai",
+                "name": "gpt-4o",
+                "mode": "chat",
+                "completion_params": {},
+            },
+        },
+    },
 }
@@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.feature_service import FeatureService

-ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
+ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]

 register_enum_models(console_ns, IconType)
@@ -61,7 +61,7 @@ _logger = logging.getLogger(__name__)
 class AppListQuery(BaseModel):
     page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
     limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
-    mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
+    mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
         default="all", description="App mode filter"
     )
     name: str | None = Field(default=None, description="Filter by app name")
@@ -93,7 +93,9 @@ class AppListQuery(BaseModel):
 class CreateAppPayload(BaseModel):
     name: str = Field(..., min_length=1, description="App name")
     description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
-    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
+    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
+        ..., description="App mode"
+    )
     icon_type: IconType | None = Field(default=None, description="Icon type")
     icon: str | None = Field(default=None, description="Icon")
     icon_background: str | None = Field(default=None, description="Icon background color")
@@ -1,3 +1,5 @@
+import uuid
+from collections.abc import Mapping
 from enum import StrEnum
 from typing import Any, Union

@@ -92,3 +94,79 @@ class AgentInvokeMessage(ToolInvokeMessage):
    """

    pass


class ExecutionContext(BaseModel):
    """Execution context containing trace and audit information.

    Carries IDs and metadata needed for tracing, auditing, and correlation,
    but is not part of the core business logic.
    """

    user_id: str | None = None
    app_id: str | None = None
    conversation_id: str | None = None
    message_id: str | None = None
    tenant_id: str | None = None

    @classmethod
    def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
        return cls(user_id=user_id)

    def to_dict(self) -> dict[str, Any]:
        return {
            "user_id": self.user_id,
            "app_id": self.app_id,
            "conversation_id": self.conversation_id,
            "message_id": self.message_id,
            "tenant_id": self.tenant_id,
        }

    def with_updates(self, **kwargs) -> "ExecutionContext":
        data = self.to_dict()
        data.update(kwargs)
        return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})

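The immutable-update style of ExecutionContext is easiest to see in use. A minimal sketch (the IDs are made up):

```python
# Minimal usage sketch for ExecutionContext; the IDs are illustrative.
from core.agent.entities import ExecutionContext

ctx = ExecutionContext.create_minimal(user_id="user-123")
run_ctx = ctx.with_updates(app_id="app-456", conversation_id="conv-789")

assert ctx.app_id is None             # the original context is unchanged
assert run_ctx.user_id == "user-123"  # base fields are carried over
print(run_ctx.to_dict())
```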
class AgentLog(BaseModel):
    """Structured log entry for agent execution tracing."""

    class LogType(StrEnum):
        ROUND = "round"
        THOUGHT = "thought"
        TOOL_CALL = "tool_call"

    class LogMetadata(StrEnum):
        STARTED_AT = "started_at"
        FINISHED_AT = "finished_at"
        ELAPSED_TIME = "elapsed_time"
        TOTAL_PRICE = "total_price"
        TOTAL_TOKENS = "total_tokens"
        PROVIDER = "provider"
        CURRENCY = "currency"
        LLM_USAGE = "llm_usage"
        ICON = "icon"
        ICON_DARK = "icon_dark"

    class LogStatus(StrEnum):
        START = "start"
        ERROR = "error"
        SUCCESS = "success"

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    label: str = Field(...)
    log_type: LogType = Field(...)
    parent_id: str | None = Field(default=None)
    error: str | None = Field(default=None)
    status: LogStatus = Field(...)
    data: Mapping[str, Any] = Field(...)
    metadata: Mapping[LogMetadata, Any] = Field(default={})


class AgentResult(BaseModel):
    """Agent execution result."""

    text: str = Field(default="")
    files: list[Any] = Field(default_factory=list)
    usage: Any | None = Field(default=None)
    finish_reason: str | None = Field(default=None)

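AgentLog entries form a tree via parent_id, which the strategies below rely on to nest thought and tool-call logs under a round. A hedged construction sketch (labels are made up):

```python
# Sketch: building a round log with a nested thought log, as the strategies do.
from core.agent.entities import AgentLog

round_log = AgentLog(
    label="ROUND 1",
    log_type=AgentLog.LogType.ROUND,
    status=AgentLog.LogStatus.START,
    data={},
)
thought_log = AgentLog(
    label="gpt-4o Thought",
    log_type=AgentLog.LogType.THOUGHT,
    status=AgentLog.LogStatus.START,
    data={},
    parent_id=round_log.id,  # links the child entry to its round
)
```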
19  api/core/agent/patterns/__init__.py  Normal file
@@ -0,0 +1,19 @@
"""Agent patterns module.

This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""

from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory

__all__ = [
    "AgentPattern",
    "FunctionCallStrategy",
    "ReActStrategy",
    "StrategyFactory",
]
506  api/core/agent/patterns/base.py  Normal file
@@ -0,0 +1,506 @@
"""Base class for agent strategies."""

from __future__ import annotations

import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any

from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.model_manager import ModelInstance
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
from graphon.file import File
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    PromptMessage,
    PromptMessageTool,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent

if TYPE_CHECKING:
    from core.tools.__base.tool import Tool

# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]

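A conforming hook only needs to match the alias above. A sketch, assuming a hypothetical real_invoke callable that performs the actual invocation:

```python
# Hedged sketch of a ToolInvokeHook implementation; real_invoke is hypothetical.
from typing import Any

def logging_hook(tool: "Tool", tool_args: dict[str, Any], tool_name: str) -> tuple[str, list[str], "ToolInvokeMeta"]:
    # Delegate to some app-specific invocation path, then record the call.
    response_content, message_file_ids, meta = real_invoke(tool, tool_args, tool_name)
    print(f"[tool] {tool_name} called with args {sorted(tool_args)}")
    return response_content, message_file_ids, meta
```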
class AgentPattern(ABC):
    """Base class for agent execution strategies."""

    def __init__(
        self,
        model_instance: ModelInstance,
        tools: list[Tool],
        context: ExecutionContext,
        max_iterations: int = 10,
        workflow_call_depth: int = 0,
        files: list[File] = [],
        tool_invoke_hook: ToolInvokeHook | None = None,
    ):
        """Initialize the agent strategy."""
        self.model_instance = model_instance
        self.tools = tools
        self.context = context
        self.max_iterations = min(max_iterations, 99)  # Cap at 99 iterations
        self.workflow_call_depth = workflow_call_depth
        self.files: list[File] = files
        self.tool_invoke_hook = tool_invoke_hook

    @abstractmethod
    def run(
        self,
        prompt_messages: list[PromptMessage],
        model_parameters: dict[str, Any],
        stop: list[str] = [],
        stream: bool = True,
    ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
        """Execute the agent strategy."""
        pass

    def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
        """Accumulate LLM usage statistics."""
        if not total_usage.get("usage"):
            # Create a copy to avoid modifying the original
            total_usage["usage"] = LLMUsage(
                prompt_tokens=delta_usage.prompt_tokens,
                prompt_unit_price=delta_usage.prompt_unit_price,
                prompt_price_unit=delta_usage.prompt_price_unit,
                prompt_price=delta_usage.prompt_price,
                completion_tokens=delta_usage.completion_tokens,
                completion_unit_price=delta_usage.completion_unit_price,
                completion_price_unit=delta_usage.completion_price_unit,
                completion_price=delta_usage.completion_price,
                total_tokens=delta_usage.total_tokens,
                total_price=delta_usage.total_price,
                currency=delta_usage.currency,
                latency=delta_usage.latency,
            )
        else:
            current: LLMUsage = total_usage["usage"]
            current.prompt_tokens += delta_usage.prompt_tokens
            current.completion_tokens += delta_usage.completion_tokens
            current.total_tokens += delta_usage.total_tokens
            current.prompt_price += delta_usage.prompt_price
            current.completion_price += delta_usage.completion_price
            current.total_price += delta_usage.total_price

    def _extract_content(self, content: Any) -> str:
        """Extract text content from message content."""
        if isinstance(content, list):
            # Content items are PromptMessageContentUnionTypes
            text_parts = []
            for c in content:
                # Check if it's a TextPromptMessageContent (which has data attribute)
                if isinstance(c, TextPromptMessageContent):
                    text_parts.append(c.data)
            return "".join(text_parts)
        return str(content)

    def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
        """Check if chunk contains tool calls."""
        # LLMResultChunk always has delta attribute
        return bool(chunk.delta.message and chunk.delta.message.tool_calls)

    def _has_tool_calls_result(self, result: LLMResult) -> bool:
        """Check if result contains tool calls (non-streaming)."""
        # LLMResult always has message attribute
        return bool(result.message and result.message.tool_calls)

    def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """Extract tool calls from streaming chunk."""
        tool_calls: list[tuple[str, str, dict[str, Any]]] = []
        if chunk.delta.message and chunk.delta.message.tool_calls:
            for tool_call in chunk.delta.message.tool_calls:
                if tool_call.function:
                    try:
                        args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
                    except json.JSONDecodeError:
                        args = {}
                    tool_calls.append((tool_call.id or "", tool_call.function.name, args))
        return tool_calls

    def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """Extract tool calls from non-streaming result."""
        tool_calls = []
        if result.message and result.message.tool_calls:
            for tool_call in result.message.tool_calls:
                if tool_call.function:
                    try:
                        args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
                    except json.JSONDecodeError:
                        args = {}
                    tool_calls.append((tool_call.id or "", tool_call.function.name, args))
        return tool_calls

    def _extract_text_from_message(self, message: PromptMessage) -> str:
        """Extract text content from a prompt message."""
        # PromptMessage always has content attribute
        content = message.content
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # Extract text from content list
            text_parts = []
            for item in content:
                if isinstance(item, TextPromptMessageContent):
                    text_parts.append(item.data)
            return " ".join(text_parts)
        return ""

    def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
        """Get metadata for a tool including provider and icon info."""
        from core.tools.tool_manager import ToolManager

        metadata: dict[AgentLog.LogMetadata, Any] = {}
        if tool_instance.entity and tool_instance.entity.identity:
            identity = tool_instance.entity.identity
            if identity.provider:
                metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider

            # Get icon using ToolManager for proper URL generation
            tenant_id = self.context.tenant_id
            if tenant_id and identity.provider:
                try:
                    provider_type = tool_instance.tool_provider_type()
                    icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
                    if isinstance(icon, str):
                        metadata[AgentLog.LogMetadata.ICON] = icon
                    elif isinstance(icon, dict):
                        # Handle icon dict with background/content or light/dark variants
                        metadata[AgentLog.LogMetadata.ICON] = icon
                except Exception:
                    # Fallback to identity.icon if ToolManager fails
                    if identity.icon:
                        metadata[AgentLog.LogMetadata.ICON] = identity.icon
            elif identity.icon:
                metadata[AgentLog.LogMetadata.ICON] = identity.icon
        return metadata

    def _create_log(
        self,
        label: str,
        log_type: AgentLog.LogType,
        status: AgentLog.LogStatus,
        data: dict[str, Any] | None = None,
        parent_id: str | None = None,
        extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
    ) -> AgentLog:
        """Create a new AgentLog with standard metadata."""
        metadata: dict[AgentLog.LogMetadata, Any] = {
            AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
        }
        if extra_metadata:
            metadata.update(extra_metadata)

        return AgentLog(
            label=label,
            log_type=log_type,
            status=status,
            data=data or {},
            parent_id=parent_id,
            metadata=metadata,
        )

    def _finish_log(
        self,
        log: AgentLog,
        data: dict[str, Any] | None = None,
        usage: LLMUsage | None = None,
    ) -> AgentLog:
        """Finish an AgentLog by updating its status and metadata."""
        log.status = AgentLog.LogStatus.SUCCESS

        if data is not None:
            log.data = data

        # Calculate elapsed time
        started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
        finished_at = time.perf_counter()

        # Update metadata
        log.metadata = {
            **log.metadata,
            AgentLog.LogMetadata.FINISHED_AT: finished_at,
            # Calculate elapsed time in seconds
            AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
        }

        # Add usage information if provided
        if usage:
            log.metadata.update(
                {
                    AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
                    AgentLog.LogMetadata.CURRENCY: usage.currency,
                    AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
                    AgentLog.LogMetadata.LLM_USAGE: usage,
                }
            )

        return log

    def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
        """
        Replace file references in tool arguments with actual File objects.

        Args:
            tool_args: Dictionary of tool arguments

        Returns:
            Updated tool arguments with file references replaced
        """
        # Process each argument in the dictionary
        processed_args: dict[str, Any] = {}
        for key, value in tool_args.items():
            processed_args[key] = self._process_file_reference(value)
        return processed_args

    def _process_file_reference(self, data: Any) -> Any:
        """
        Recursively process data to replace file references.
        Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].

        Args:
            data: The data to process (can be dict, list, str, or other types)

        Returns:
            Processed data with file references replaced
        """
        single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
        multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")

        if isinstance(data, dict):
            # Process dictionary recursively
            return {key: self._process_file_reference(value) for key, value in data.items()}
        elif isinstance(data, list):
            # Process list recursively
            return [self._process_file_reference(item) for item in data]
        elif isinstance(data, str):
            # Check for single file pattern [File: file_id]
            single_match = single_file_pattern.match(data.strip())
            if single_match:
                file_id = single_match.group(1).strip()
                # Find the file in self.files
                for file in self.files:
                    if file.id and str(file.id) == file_id:
                        return file
                # If file not found, return original value
                return data

            # Check for multiple files pattern [Files: file_id1, file_id2, ...]
            multiple_match = multiple_files_pattern.match(data.strip())
            if multiple_match:
                file_ids_str = multiple_match.group(1).strip()
                # Split by comma and strip whitespace
                file_ids = [fid.strip() for fid in file_ids_str.split(",")]

                # Find all matching files
                matched_files: list[File] = []
                for file_id in file_ids:
                    for file in self.files:
                        if file.id and str(file.id) == file_id:
                            matched_files.append(file)
                            break

                # Return list of files if any were found, otherwise return original
                return matched_files or data

            return data
        else:
            # Return other types as-is
            return data

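The file-reference convention handled above is easiest to see with concrete arguments (the IDs and the strategy instance are illustrative):

```python
# Sketch: how [File: ...] / [Files: ...] strings are resolved by the method above.
tool_args = {
    "image": "[File: file-1]",                 # becomes the matching File object
    "attachments": "[Files: file-1, file-2]",  # becomes a list of File objects
    "caption": "hello",                        # plain strings pass through unchanged
}
resolved = strategy._replace_file_references(tool_args)
```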
    def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
        """Create a text chunk for streaming."""
        return LLMResultChunk(
            model=self.model_instance.model_name,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=text),
                usage=None,
            ),
            system_fingerprint="",
        )

    def _invoke_tool(
        self,
        tool_instance: Tool,
        tool_args: dict[str, Any],
        tool_name: str,
    ) -> tuple[str, list[File], ToolInvokeMeta | None]:
        """
        Invoke a tool and collect its response.

        Args:
            tool_instance: The tool instance to invoke
            tool_args: Tool arguments
            tool_name: Name of the tool

        Returns:
            Tuple of (response_content, tool_files, tool_invoke_meta)
        """
        # Process tool_args to replace file references with actual File objects
        tool_args = self._replace_file_references(tool_args)

        # If a tool invoke hook is set, use it instead of generic_invoke
        if self.tool_invoke_hook:
            response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
            # Note: message_file_ids are stored in DB, we don't convert them to File objects here
            # The caller (AgentAppRunner) handles file publishing
            return response_content, [], tool_invoke_meta

        # Default: use generic_invoke for workflow scenarios
        # Import here to avoid circular import
        from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine

        tool_response = ToolEngine.generic_invoke(
            tool=tool_instance,
            tool_parameters=tool_args,
            user_id=self.context.user_id or "",
            workflow_tool_callback=DifyWorkflowCallbackHandler(),
            workflow_call_depth=self.workflow_call_depth,
            app_id=self.context.app_id,
            conversation_id=self.context.conversation_id,
            message_id=self.context.message_id,
        )

        # Collect response and files
        response_content = ""
        tool_files: list[File] = []

        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(response.message, ToolInvokeMessage.TextMessage)
                response_content += response.message.text

            elif response.type == ToolInvokeMessage.MessageType.LINK:
                # Handle link messages
                if isinstance(response.message, ToolInvokeMessage.TextMessage):
                    response_content += f"[Link: {response.message.text}]"

            elif response.type == ToolInvokeMessage.MessageType.IMAGE:
                # Handle image URL messages
                if isinstance(response.message, ToolInvokeMessage.TextMessage):
                    response_content += f"[Image: {response.message.text}]"

            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
                # Handle image link messages
                if isinstance(response.message, ToolInvokeMessage.TextMessage):
                    response_content += f"[Image: {response.message.text}]"

            elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
                # Handle binary file link messages
                if isinstance(response.message, ToolInvokeMessage.TextMessage):
                    filename = response.meta.get("filename", "file") if response.meta else "file"
                    response_content += f"[File: {filename} - {response.message.text}]"

            elif response.type == ToolInvokeMessage.MessageType.JSON:
                # Handle JSON messages
                if isinstance(response.message, ToolInvokeMessage.JsonMessage):
                    response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)

            elif response.type == ToolInvokeMessage.MessageType.BLOB:
                # Handle blob messages - convert to text representation
                if isinstance(response.message, ToolInvokeMessage.BlobMessage):
                    mime_type = (
                        response.meta.get("mime_type", "application/octet-stream")
                        if response.meta
                        else "application/octet-stream"
                    )
                    size = len(response.message.blob)
                    response_content += f"[Binary data: {mime_type}, size: {size} bytes]"

            elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
                # Handle variable messages
                if isinstance(response.message, ToolInvokeMessage.VariableMessage):
                    var_name = response.message.variable_name
                    var_value = response.message.variable_value
                    if isinstance(var_value, str):
                        response_content += var_value
                    else:
                        response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"

            elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
                # Handle blob chunk messages - these are parts of a larger blob
                if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
                    response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"

            elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
                # Handle retriever resources messages
                if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
                    response_content += response.message.context

            elif response.type == ToolInvokeMessage.MessageType.FILE:
                # Extract file from meta
                if response.meta and "file" in response.meta:
                    file = response.meta["file"]
                    if isinstance(file, File):
                        # Check if file is for model or tool output
                        if response.meta.get("target") == "self":
                            # File is for model - add to files for next prompt
                            self.files.append(file)
                            response_content += f"File '{file.filename}' has been loaded into your context."
                        else:
                            # File is tool output
                            tool_files.append(file)

        return response_content, tool_files, None

    def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
        """Validate tool arguments against the tool's required parameters.

        Checks that all required LLM-facing parameters are present and non-empty
        before actual execution, preventing wasted tool invocations when the model
        generates calls with missing arguments (e.g. empty ``{}``).

        Returns:
            Error message if validation fails, None if all required parameters are satisfied.
        """
        prompt_tool = tool_instance.to_prompt_message_tool()
        required_params: list[str] = prompt_tool.parameters.get("required", [])

        if not required_params:
            return None

        missing = [
            p
            for p in required_params
            if p not in tool_args
            or tool_args[p] is None
            or (isinstance(tool_args[p], str) and not tool_args[p].strip())
        ]

        if not missing:
            return None

        return (
            f"Missing required parameter(s): {', '.join(missing)}. "
            f"Please provide all required parameters before calling this tool."
        )

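The validation contract: None on success, an error string on failure. An illustrative call (the tool and its required "query" parameter are hypothetical):

```python
# Empty strings count as missing, so this call is rejected before execution.
error = strategy._validate_tool_args(search_tool, {"query": ""})
# error == "Missing required parameter(s): query. Please provide all required
#           parameters before calling this tool."
```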
    def _find_tool_by_name(self, tool_name: str) -> Tool | None:
        """Find a tool instance by its name."""
        for tool in self.tools:
            if tool.entity.identity.name == tool_name:
                return tool
        return None

    def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
        """Convert tools to prompt message format."""
        prompt_tools: list[PromptMessageTool] = []
        for tool in self.tools:
            prompt_tools.append(tool.to_prompt_message_tool())
        return prompt_tools

    def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
        """Initialize usage tracking with empty usage if not set."""
        if "usage" not in llm_usage or llm_usage["usage"] is None:
            llm_usage["usage"] = LLMUsage.empty_usage()
359  api/core/agent/patterns/function_call.py  Normal file
@@ -0,0 +1,359 @@
"""Function Call strategy implementation.

Implements the Function Call agent pattern where the LLM uses native tool-calling
capability to invoke tools. Includes pre-execution parameter validation that
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
and avoids counting purely-invalid rounds against the iteration budget.
"""

import json
import logging
from collections.abc import Generator
from typing import Any, Union

from core.agent.entities import AgentLog, AgentResult
from core.tools.entities.tool_entities import ToolInvokeMeta
from graphon.file import File
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageTool,
    ToolPromptMessage,
)

from .base import AgentPattern

logger = logging.getLogger(__name__)


class FunctionCallStrategy(AgentPattern):
    """Function Call strategy using model's native tool calling capability."""

    def run(
        self,
        prompt_messages: list[PromptMessage],
        model_parameters: dict[str, Any],
        stop: list[str] = [],
        stream: bool = True,
    ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
        """Execute the function call agent strategy."""
        # Convert tools to prompt format
        prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()

        # Initialize tracking
        iteration_step: int = 1
        max_iterations: int = self.max_iterations + 1
        function_call_state: bool = True
        total_usage: dict[str, LLMUsage | None] = {"usage": None}
        messages: list[PromptMessage] = list(prompt_messages)  # Create mutable copy
        final_text: str = ""
        finish_reason: str | None = None
        output_files: list[File] = []  # Track files produced by tools
        # Consecutive rounds where ALL tool calls failed parameter validation.
        # When this happens the round is "free" (iteration_step not incremented)
        # up to a safety cap to prevent infinite loops.
        consecutive_validation_failures: int = 0
        max_validation_retries: int = 3

        while function_call_state and iteration_step <= max_iterations:
            function_call_state = False
            round_log = self._create_log(
                label=f"ROUND {iteration_step}",
                log_type=AgentLog.LogType.ROUND,
                status=AgentLog.LogStatus.START,
                data={},
            )
            yield round_log
            # On last iteration, remove tools to force final answer
            current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
            model_log = self._create_log(
                label=f"{self.model_instance.model_name} Thought",
                log_type=AgentLog.LogType.THOUGHT,
                status=AgentLog.LogStatus.START,
                data={},
                parent_id=round_log.id,
                extra_metadata={
                    AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
                },
            )
            yield model_log

            # Track usage for this round only
            round_usage: dict[str, LLMUsage | None] = {"usage": None}

            # Invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
                prompt_messages=messages,
                model_parameters=model_parameters,
                tools=current_tools,
                stop=stop,
                stream=stream,
                user=self.context.user_id,
                callbacks=[],
            )

            # Process response
            tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
                chunks, round_usage, model_log
            )
            messages.append(self._create_assistant_message(response_content, tool_calls))

            # Accumulate to total usage
            round_usage_value = round_usage.get("usage")
            if round_usage_value:
                self._accumulate_usage(total_usage, round_usage_value)

            # Update final text if no tool calls (this is likely the final answer)
            if not tool_calls:
                final_text = response_content

            # Update finish reason
            if chunk_finish_reason:
                finish_reason = chunk_finish_reason

            # Process tool calls
            tool_outputs: dict[str, str] = {}
            all_validation_errors: bool = True
            if tool_calls:
                function_call_state = True
                # Execute tools (with pre-execution parameter validation)
                for tool_call_id, tool_name, tool_args in tool_calls:
                    tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
                        tool_name, tool_args, tool_call_id, messages, round_log
                    )
                    tool_outputs[tool_name] = tool_response
                    output_files.extend(tool_files)
                    if not is_validation_error:
                        all_validation_errors = False
            else:
                all_validation_errors = False

            yield self._finish_log(
                round_log,
                data={
                    "llm_result": response_content,
                    "tool_calls": [
                        {"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
                    ]
                    if tool_calls
                    else [],
                    "final_answer": final_text if not function_call_state else None,
                },
                usage=round_usage.get("usage"),
            )

            # Skip iteration counter when every tool call in this round failed validation,
            # giving the model a free retry — but cap retries to prevent infinite loops.
            if tool_calls and all_validation_errors:
                consecutive_validation_failures += 1
                if consecutive_validation_failures >= max_validation_retries:
                    logger.warning(
                        "Agent hit %d consecutive validation-only rounds, forcing iteration increment",
                        consecutive_validation_failures,
                    )
                    iteration_step += 1
                    consecutive_validation_failures = 0
                else:
                    logger.info(
                        "All tool calls failed validation (attempt %d/%d), not counting iteration",
                        consecutive_validation_failures,
                        max_validation_retries,
                    )
            else:
                consecutive_validation_failures = 0
                iteration_step += 1

        # Return final result
        from core.agent.entities import AgentResult

        return AgentResult(
            text=final_text,
            files=output_files,
            usage=total_usage.get("usage") or LLMUsage.empty_usage(),
            finish_reason=finish_reason,
        )

    def _handle_chunks(
        self,
        chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
        llm_usage: dict[str, LLMUsage | None],
        start_log: AgentLog,
    ) -> Generator[
        LLMResultChunk | AgentLog,
        None,
        tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
    ]:
        """Handle LLM response chunks and extract tool calls and content.

        Returns a tuple of (tool_calls, response_content, finish_reason).
        """
        tool_calls: list[tuple[str, str, dict[str, Any]]] = []
        response_content: str = ""
        finish_reason: str | None = None
        if not isinstance(chunks, LLMResult):
            # Streaming response
            for chunk in chunks:
                # Extract tool calls
                if self._has_tool_calls(chunk):
                    tool_calls.extend(self._extract_tool_calls(chunk))

                # Extract content
                if chunk.delta.message and chunk.delta.message.content:
                    response_content += self._extract_content(chunk.delta.message.content)

                # Track usage
                if chunk.delta.usage:
                    self._accumulate_usage(llm_usage, chunk.delta.usage)

                # Capture finish reason
                if chunk.delta.finish_reason:
                    finish_reason = chunk.delta.finish_reason

                yield chunk
        else:
            # Non-streaming response
            result: LLMResult = chunks

            if self._has_tool_calls_result(result):
                tool_calls.extend(self._extract_tool_calls_result(result))

            if result.message and result.message.content:
                response_content += self._extract_content(result.message.content)

            if result.usage:
                self._accumulate_usage(llm_usage, result.usage)

            # Convert to streaming format
            yield LLMResultChunk(
                model=result.model,
                prompt_messages=result.prompt_messages,
                delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
            )
        yield self._finish_log(
            start_log,
            data={
                "result": response_content,
            },
            usage=llm_usage.get("usage"),
        )
        return tool_calls, response_content, finish_reason

    def _create_assistant_message(
        self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
    ) -> AssistantPromptMessage:
        """Create assistant message with tool calls."""
        if tool_calls is None:
            return AssistantPromptMessage(content=content)
        return AssistantPromptMessage(
            content=content or "",
            tool_calls=[
                AssistantPromptMessage.ToolCall(
                    id=tc[0],
                    type="function",
                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
                )
                for tc in tool_calls
            ],
        )

    def _handle_tool_call(
        self,
        tool_name: str,
        tool_args: dict[str, Any],
        tool_call_id: str,
        messages: list[PromptMessage],
        round_log: AgentLog,
    ) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
        """Handle a single tool call and return response with files, meta, and validation status.

        Validates required parameters before execution. When validation fails the tool
        is never invoked — a synthetic error is fed back to the model so it can self-correct
        without consuming a real iteration.

        Returns:
            (response_content, tool_files, tool_invoke_meta, is_validation_error).
            ``is_validation_error`` is True when the call was rejected due to missing
            required parameters, allowing the caller to skip the iteration counter.
        """
        # Find tool
        tool_instance = self._find_tool_by_name(tool_name)
        if not tool_instance:
            raise ValueError(f"Tool {tool_name} not found")

        # Get tool metadata (provider, icon, etc.)
        tool_metadata = self._get_tool_metadata(tool_instance)

        # Create tool call log
        tool_call_log = self._create_log(
            label=f"CALL {tool_name}",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={
                "tool_call_id": tool_call_id,
                "tool_name": tool_name,
                "tool_args": tool_args,
            },
            parent_id=round_log.id,
            extra_metadata=tool_metadata,
        )
        yield tool_call_log

        # Validate required parameters before execution to avoid wasted invocations
        validation_error = self._validate_tool_args(tool_instance, tool_args)
        if validation_error:
            tool_call_log.status = AgentLog.LogStatus.ERROR
            tool_call_log.error = validation_error
            tool_call_log.data = {**tool_call_log.data, "error": validation_error}
            yield tool_call_log

            messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
            return validation_error, [], None, True

        # Invoke tool using base class method with error handling
        try:
            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)

            yield self._finish_log(
                tool_call_log,
                data={
                    **tool_call_log.data,
                    "output": response_content,
                    "files": len(tool_files),
                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
                },
            )
            final_content = response_content or "Tool executed successfully"
            # Add tool response to messages
            messages.append(
                ToolPromptMessage(
                    content=final_content,
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
            )
            return response_content, tool_files, tool_invoke_meta, False
        except Exception as e:
            # Tool invocation failed, yield error log
            error_message = str(e)
            tool_call_log.status = AgentLog.LogStatus.ERROR
            tool_call_log.error = error_message
            tool_call_log.data = {
                **tool_call_log.data,
                "error": error_message,
            }
            yield tool_call_log

            # Add error message to conversation
            error_content = f"Tool execution failed: {error_message}"
            messages.append(
                ToolPromptMessage(
                    content=error_content,
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
            )
            return error_content, [], None, False

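Since run() is a generator that yields chunks and logs but returns its AgentResult, a caller has to recover the result from StopIteration. A hedged driver sketch (strategy and messages are assumed to exist):

```python
# Sketch: driving a strategy's run() generator and capturing its return value.
from core.agent.entities import AgentLog

gen = strategy.run(prompt_messages=messages, model_parameters={}, stream=True)
while True:
    try:
        event = next(gen)
    except StopIteration as stop:
        result = stop.value  # the AgentResult returned by run()
        break
    if isinstance(event, AgentLog):
        print(f"[log] {event.label} ({event.status})")
    else:  # an LLMResultChunk
        print(event.delta.message.content or "", end="")
```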
419  api/core/agent/patterns/react.py  Normal file
@@ -0,0 +1,419 @@
"""ReAct strategy implementation."""

from __future__ import annotations

import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union

from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.model_manager import ModelInstance
from graphon.file import File
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    PromptMessage,
    SystemPromptMessage,
)

from .base import AgentPattern, ToolInvokeHook

if TYPE_CHECKING:
    from core.tools.__base.tool import Tool


class ReActStrategy(AgentPattern):
    """ReAct strategy using reasoning and acting approach."""

    def __init__(
        self,
        model_instance: ModelInstance,
        tools: list[Tool],
        context: ExecutionContext,
        max_iterations: int = 10,
        workflow_call_depth: int = 0,
        files: list[File] = [],
        tool_invoke_hook: ToolInvokeHook | None = None,
        instruction: str = "",
    ):
        """Initialize the ReAct strategy with instruction support."""
        super().__init__(
            model_instance=model_instance,
            tools=tools,
            context=context,
            max_iterations=max_iterations,
            workflow_call_depth=workflow_call_depth,
            files=files,
            tool_invoke_hook=tool_invoke_hook,
        )
        self.instruction = instruction

    def run(
        self,
        prompt_messages: list[PromptMessage],
        model_parameters: dict[str, Any],
        stop: list[str] = [],
        stream: bool = True,
    ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
        """Execute the ReAct agent strategy."""
        # Initialize tracking
        agent_scratchpad: list[AgentScratchpadUnit] = []
        iteration_step: int = 1
        max_iterations: int = self.max_iterations + 1
        react_state: bool = True
        total_usage: dict[str, Any] = {"usage": None}
        output_files: list[File] = []  # Track files produced by tools
        final_text: str = ""
        finish_reason: str | None = None

        # Add "Observation" to stop sequences
        if "Observation" not in stop:
            stop = stop.copy()
            stop.append("Observation")

        while react_state and iteration_step <= max_iterations:
            react_state = False
            round_log = self._create_log(
                label=f"ROUND {iteration_step}",
                log_type=AgentLog.LogType.ROUND,
                status=AgentLog.LogStatus.START,
                data={},
            )
            yield round_log

            # Build prompt with/without tools based on iteration
            include_tools = iteration_step < max_iterations
            current_messages = self._build_prompt_with_react_format(
                prompt_messages, agent_scratchpad, include_tools, self.instruction
            )

            model_log = self._create_log(
                label=f"{self.model_instance.model_name} Thought",
                log_type=AgentLog.LogType.THOUGHT,
                status=AgentLog.LogStatus.START,
                data={},
                parent_id=round_log.id,
                extra_metadata={
                    AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
                },
            )
            yield model_log

            # Track usage for this round only
            round_usage: dict[str, Any] = {"usage": None}

            # Use current messages directly (files are handled by base class if needed)
            messages_to_use = current_messages

            # Invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
                prompt_messages=messages_to_use,
                model_parameters=model_parameters,
                stop=stop,
                stream=stream,
                user=self.context.user_id or "",
                callbacks=[],
            )

            # Process response
            scratchpad, chunk_finish_reason = yield from self._handle_chunks(
                chunks, round_usage, model_log, current_messages
            )
            agent_scratchpad.append(scratchpad)

            # Accumulate to total usage
            round_usage_value = round_usage.get("usage")
            if round_usage_value:
                self._accumulate_usage(total_usage, round_usage_value)

            # Update finish reason
            if chunk_finish_reason:
                finish_reason = chunk_finish_reason

            # Check if we have an action to execute
            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                react_state = True
                # Execute tool
                observation, tool_files = yield from self._handle_tool_call(
                    scratchpad.action, current_messages, round_log
                )
                scratchpad.observation = observation
                # Track files produced by tools
                output_files.extend(tool_files)

                # Add observation to scratchpad for display
                yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
            else:
                # Extract final answer
                if scratchpad.action and scratchpad.action.action_input:
                    final_answer = scratchpad.action.action_input
                    if isinstance(final_answer, dict):
                        final_answer = json.dumps(final_answer, ensure_ascii=False)
                    final_text = str(final_answer)
                elif scratchpad.thought:
                    # If no action but we have thought, use thought as final answer
                    final_text = scratchpad.thought

            yield self._finish_log(
                round_log,
                data={
                    "thought": scratchpad.thought,
                    "action": scratchpad.action_str if scratchpad.action else None,
                    "observation": scratchpad.observation or None,
                    "final_answer": final_text if not react_state else None,
                },
                usage=round_usage.get("usage"),
            )
            iteration_step += 1

        # Return final result
        from core.agent.entities import AgentResult

        return AgentResult(
            text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
        )

    def _build_prompt_with_react_format(
        self,
        original_messages: list[PromptMessage],
        agent_scratchpad: list[AgentScratchpadUnit],
        include_tools: bool = True,
        instruction: str = "",
    ) -> list[PromptMessage]:
        """Build prompt messages with ReAct format."""
        # Copy messages to avoid modifying original
        messages = list(original_messages)

        # Find and update the system prompt that should already exist
        system_prompt_found = False
        for i, msg in enumerate(messages):
            if isinstance(msg, SystemPromptMessage):
                system_prompt_found = True
                # The system prompt from frontend already has the template, just replace placeholders

                # Format tools
                tools_str = ""
                tool_names = []
                if include_tools and self.tools:
                    # Convert tools to prompt message tools format
                    prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
                    tool_names = [tool.name for tool in prompt_tools]

                    # Format tools as JSON for comprehensive information
                    from graphon.model_runtime.utils.encoders import jsonable_encoder

                    tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
                    tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
                else:
                    tools_str = "No tools available"
                    tool_names_str = ""

                # Replace placeholders in the existing system prompt
                updated_content = msg.content
                assert isinstance(updated_content, str)
                updated_content = updated_content.replace("{{instruction}}", instruction)
                updated_content = updated_content.replace("{{tools}}", tools_str)
                updated_content = updated_content.replace("{{tool_names}}", tool_names_str)

                # Create new SystemPromptMessage with updated content
                messages[i] = SystemPromptMessage(content=updated_content)
                break

        # If no system prompt found, that's unexpected but add scratchpad anyway
        if not system_prompt_found:
            # This shouldn't happen if frontend is working correctly
            pass

        # Format agent scratchpad
        scratchpad_str = ""
        if agent_scratchpad:
            scratchpad_parts: list[str] = []
            for unit in agent_scratchpad:
                if unit.thought:
                    scratchpad_parts.append(f"Thought: {unit.thought}")
                if unit.action_str:
                    scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
                if unit.observation:
                    scratchpad_parts.append(f"Observation: {unit.observation}")
            scratchpad_str = "\n".join(scratchpad_parts)

        # If there's a scratchpad, append it to the last message
        if scratchpad_str:
            messages.append(AssistantPromptMessage(content=scratchpad_str))

        return messages

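For orientation, the scratchpad string this builder appends as an AssistantPromptMessage looks roughly like this (content is made up; the action JSON uses the AgentScratchpadUnit.Action field names):

```python
# Illustrative scratchpad text produced by _build_prompt_with_react_format.
scratchpad_str = (
    "Thought: I should search for the answer.\n"
    "Action:\n```\n"
    '{"action_name": "web_search", "action_input": {"query": "dify agent v2"}}\n'
    "```\n"
    "Observation: <tool output appended here>"
)
```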
    def _handle_chunks(
        self,
        chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
        llm_usage: dict[str, Any],
        model_log: AgentLog,
        current_messages: list[PromptMessage],
    ) -> Generator[
        LLMResultChunk | AgentLog,
        None,
        tuple[AgentScratchpadUnit, str | None],
    ]:
        """Handle LLM response chunks and extract action/thought.

        Returns a tuple of (scratchpad_unit, finish_reason).
        """
        usage_dict: dict[str, Any] = {}

        # Convert non-streaming to streaming format if needed
        if isinstance(chunks, LLMResult):
            result = chunks

            def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
                yield LLMResultChunk(
                    model=result.model,
                    prompt_messages=result.prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                        finish_reason=None,
                    ),
                    system_fingerprint=result.system_fingerprint or "",
                )

            streaming_chunks = result_to_chunks()
        else:
            streaming_chunks = chunks

        react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)

        # Initialize scratchpad unit
        scratchpad = AgentScratchpadUnit(
            agent_response="",
            thought="",
            action_str="",
            observation="",
            action=None,
        )

        finish_reason: str | None = None

        # Process chunks
        for chunk in react_chunks:
            if isinstance(chunk, AgentScratchpadUnit.Action):
                # Action detected
                action_str = json.dumps(chunk.model_dump())
                scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
                scratchpad.action_str = action_str
                scratchpad.action = chunk

                yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
            else:
                # Text chunk
                chunk_text = str(chunk)
                scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
                scratchpad.thought = (scratchpad.thought or "") + chunk_text

                yield self._create_text_chunk(chunk_text, current_messages)

        # Update usage
        if usage_dict.get("usage"):
            if llm_usage.get("usage"):
                self._accumulate_usage(llm_usage, usage_dict["usage"])
            else:
                llm_usage["usage"] = usage_dict["usage"]

        # Clean up thought
        scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"

        # Finish model log
        yield self._finish_log(
            model_log,
            data={
                "thought": scratchpad.thought,
                "action": scratchpad.action_str if scratchpad.action else None,
            },
            usage=llm_usage.get("usage"),
        )

        return scratchpad, finish_reason

    def _handle_tool_call(
        self,
        action: AgentScratchpadUnit.Action,
        prompt_messages: list[PromptMessage],
        round_log: AgentLog,
    ) -> Generator[AgentLog, None, tuple[str, list[File]]]:
        """Handle tool call and return observation with files."""
        tool_name = action.action_name
        tool_args: dict[str, Any] | str = action.action_input

        # Find tool instance first to get metadata
        tool_instance = self._find_tool_by_name(tool_name)
        tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}

        # Start tool log with tool metadata
        tool_log = self._create_log(
            label=f"CALL {tool_name}",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={
                "tool_name": tool_name,
                "tool_args": tool_args,
            },
            parent_id=round_log.id,
            extra_metadata=tool_metadata,
        )
        yield tool_log

        if not tool_instance:
            # Finish tool log with error
            yield self._finish_log(
                tool_log,
                data={
                    **tool_log.data,
                    "error": f"Tool {tool_name} not found",
                },
            )
            return f"Tool {tool_name} not found", []

        # Ensure tool_args is a dict
        tool_args_dict: dict[str, Any]
        if isinstance(tool_args, str):
            try:
                tool_args_dict = json.loads(tool_args)
            except json.JSONDecodeError:
                tool_args_dict = {"input": tool_args}
        elif not isinstance(tool_args, dict):
            tool_args_dict = {"input": str(tool_args)}
        else:
            tool_args_dict = tool_args

        # Invoke tool using base class method with error handling
        try:
            response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)

            # Finish tool log
            yield self._finish_log(
                tool_log,
                data={
                    **tool_log.data,
                    "output": response_content,
                    "files": len(tool_files),
                    "meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
                },
            )

            return response_content or "Tool executed successfully", tool_files
        except Exception as e:
            # Tool invocation failed, yield error log
            error_message = str(e)
            tool_log.status = AgentLog.LogStatus.ERROR
            tool_log.error = error_message
            tool_log.data = {
                **tool_log.data,
                "error": error_message,
            }
            yield tool_log

            return f"Tool execution failed: {error_message}", []

108  api/core/agent/patterns/strategy_factory.py  Normal file
@ -0,0 +1,108 @@
"""Strategy factory for creating agent strategies."""

from __future__ import annotations

from typing import TYPE_CHECKING

from core.agent.entities import AgentEntity, ExecutionContext
from core.model_manager import ModelInstance
from graphon.file.models import File
from graphon.model_runtime.entities.model_entities import ModelFeature

from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy

if TYPE_CHECKING:
    from core.tools.__base.tool import Tool


class StrategyFactory:
    """Factory for creating agent strategies based on model features."""

    # Tool calling related features
    TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}

    @staticmethod
    def create_strategy(
        model_features: list[ModelFeature],
        model_instance: ModelInstance,
        context: ExecutionContext,
        tools: list[Tool],
        files: list[File],
        max_iterations: int = 10,
        workflow_call_depth: int = 0,
        agent_strategy: AgentEntity.Strategy | None = None,
        tool_invoke_hook: ToolInvokeHook | None = None,
        instruction: str = "",
    ) -> AgentPattern:
        """
        Create an appropriate strategy based on model features.

        Args:
            model_features: List of model features/capabilities
            model_instance: Model instance to use
            context: Execution context containing trace/audit information
            tools: Available tools
            files: Available files
            max_iterations: Maximum iterations for the strategy
            workflow_call_depth: Depth of workflow calls
            agent_strategy: Optional explicit strategy override
            tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
            instruction: Optional instruction for ReAct strategy

        Returns:
            AgentPattern instance
        """
        # If Function Calling is explicitly requested, use it only when the model supports it
        if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
                return FunctionCallStrategy(
                    model_instance=model_instance,
                    context=context,
                    tools=tools,
                    files=files,
                    max_iterations=max_iterations,
                    workflow_call_depth=workflow_call_depth,
                    tool_invoke_hook=tool_invoke_hook,
                )
            # FC requested but unsupported: fall through to auto-selection, which resolves to ReAct

        # If explicit strategy is Chain of Thought (ReAct)
        if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            return ReActStrategy(
                model_instance=model_instance,
                context=context,
                tools=tools,
                files=files,
                max_iterations=max_iterations,
                workflow_call_depth=workflow_call_depth,
                tool_invoke_hook=tool_invoke_hook,
                instruction=instruction,
            )

        # Default auto-selection logic
        if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
            # Model supports native function calling
            return FunctionCallStrategy(
                model_instance=model_instance,
                context=context,
                tools=tools,
                files=files,
                max_iterations=max_iterations,
                workflow_call_depth=workflow_call_depth,
                tool_invoke_hook=tool_invoke_hook,
            )
        else:
            # Use ReAct strategy for models without function calling
            return ReActStrategy(
                model_instance=model_instance,
                context=context,
                tools=tools,
                files=files,
                max_iterations=max_iterations,
                workflow_call_depth=workflow_call_depth,
                tool_invoke_hook=tool_invoke_hook,
                instruction=instruction,
            )
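The selection rule above reduces to a small decision table. A minimal sketch of just that rule, with strings standing in for the strategy objects and feature enums (all names here are illustrative):

    # Distilled from create_strategy: explicit FC wins only if supported,
    # explicit CoT always wins, otherwise features decide.
    TOOL_CALL_FEATURES = {"tool-call", "multi-tool-call", "stream-tool-call"}


    def pick_strategy(model_features: set[str], explicit: str | None = None) -> str:
        if explicit == "function-calling" and model_features & TOOL_CALL_FEATURES:
            return "function_call"
        if explicit == "chain-of-thought":
            return "react"
        return "function_call" if model_features & TOOL_CALL_FEATURES else "react"


    assert pick_strategy({"tool-call"}) == "function_call"
    assert pick_strategy(set()) == "react"
    assert pick_strategy(set(), explicit="function-calling") == "react"  # FC requested but unsupported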
@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:  # pragma: no cover
    from models.model import File

from graphon.model_runtime.entities import PromptMessageTool

from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
    ToolEntity,
@ -154,6 +156,61 @@ class Tool(ABC):

        return parameters

    def to_prompt_message_tool(self) -> PromptMessageTool:
        """Convert this tool to a PromptMessageTool for LLM consumption."""
        message_tool = PromptMessageTool(
            name=self.entity.identity.name,
            description=self.entity.description.llm if self.entity.description else "",
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = self.get_merged_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                if parameter.type == ToolParameter.ToolParameterType.FILE:
                    file_format_desc = " Input the file id with format: [File: file_id]."
                else:
                    file_format_desc = " Input the file id with format: [Files: file_id1, file_id2, ...]."

                message_tool.parameters["properties"][parameter.name] = {
                    "type": "string",
                    "description": (parameter.llm_description or "") + file_format_desc,
                }
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            message_tool.parameters["properties"][parameter.name] = (
                {
                    "type": parameter_type,
                    "description": parameter.llm_description or "",
                }
                if parameter.input_schema is None
                else parameter.input_schema
            )

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool

    def create_image_message(
        self,
        image: str,
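The method above builds the JSON-Schema-style tool definition that chat-completion APIs expect. Roughly, for a hypothetical search tool with one required LLM-form parameter, the returned PromptMessageTool serializes to something like this (all values illustrative):

    # Illustrative shape of to_prompt_message_tool() output; the actual fields
    # come from the tool's identity, description, and configured parameters.
    example_tool_schema = {
        "name": "google_search",
        "description": "Search the web.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search keywords."},
            },
            "required": ["query"],
        },
    }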
@ -52,6 +52,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
    PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE
from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter
from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
from extensions.ext_database import db

@ -394,6 +397,13 @@ class DifyNodeFactory(NodeFactory):
                "runtime_support": self._agent_runtime_support,
                "message_transformer": self._agent_message_transformer,
            },
            AGENT_V2_NODE_TYPE: lambda: {
                "tool_manager": AgentV2ToolManager(
                    tenant_id=self._dify_context.tenant_id,
                    app_id=self._dify_context.app_id,
                ),
                "event_adapter": AgentV2EventAdapter(),
            },
        }
        node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
        return node_class(
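Worth noting about the hunk above: the per-node-type lambdas keep construction lazy, so an AgentV2ToolManager is only instantiated when an agent-v2 node is actually being built. The same pattern in isolation (names illustrative):

    from typing import Any, Callable

    # Minimal sketch of the lazy kwargs-factory lookup used above.
    factories: dict[str, Callable[[], dict[str, Any]]] = {
        "agent-v2": lambda: {"tool_manager": object(), "event_adapter": object()},
    }

    kwargs = factories.get("llm", lambda: {})()  # unknown node type -> no extra kwargs
    assert kwargs == {}
    kwargs = factories.get("agent-v2", lambda: {})()  # dependencies built only now
    assert set(kwargs) == {"tool_manager", "event_adapter"}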
api/core/workflow/nodes/agent_v2/__init__.py (new file)
@ -0,0 +1,4 @@
from .entities import AgentV2NodeData
from .node import AgentV2Node

__all__ = ["AgentV2Node", "AgentV2NodeData"]
api/core/workflow/nodes/agent_v2/entities.py (new file)
@ -0,0 +1,86 @@
"""Agent V2 Node data model.

Merges LLM Node capabilities (prompt, memory, vision, context, structured output)
with Agent capabilities (tool calling loop, strategy selection).
When no tools are configured, behaves identically to an LLM Node.
"""

from collections.abc import Mapping, Sequence
from typing import Any, Literal

from graphon.entities.base_node_data import BaseNodeData
from graphon.model_runtime.entities import ImagePromptMessageContent
from graphon.nodes.llm.entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    ModelConfig,
    PromptConfig,
)
from pydantic import BaseModel, Field, field_validator

from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType

AGENT_V2_NODE_TYPE = "agent-v2"


class ContextConfig(BaseModel):
    enabled: bool
    variable_selector: list[str] | None = None


class VisionConfigOptions(BaseModel):
    variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
    detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH


class VisionConfig(BaseModel):
    enabled: bool = False
    configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)

    @field_validator("configs", mode="before")
    @classmethod
    def convert_none_configs(cls, v: Any):
        if v is None:
            return VisionConfigOptions()
        return v


class ToolMetadata(BaseModel):
    """Tool configuration for Agent V2 node."""

    enabled: bool = True
    type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
    provider_name: str = Field(..., description="Tool provider name/identifier")
    tool_name: str = Field(..., description="Tool name")
    plugin_unique_identifier: str | None = Field(None)
    credential_id: str | None = Field(None)
    parameters: dict[str, Any] = Field(default_factory=dict)
    settings: dict[str, Any] = Field(default_factory=dict)
    extra: dict[str, Any] = Field(default_factory=dict)


class AgentV2NodeData(BaseNodeData):
    """Agent V2 Node — LLM + Agent capabilities in a single workflow node."""

    type: str = AGENT_V2_NODE_TYPE

    # --- LLM capabilities (superset of LLMNodeData) ---
    model: ModelConfig
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
    memory: MemoryConfig | None = None
    context: ContextConfig = Field(default_factory=lambda: ContextConfig(enabled=False))
    vision: VisionConfig = Field(default_factory=VisionConfig)
    structured_output: Mapping[str, Any] | None = None
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
    reasoning_format: Literal["separated", "tagged"] = "tagged"

    # --- Agent capabilities ---
    tools: Sequence[ToolMetadata] = Field(default_factory=list)
    max_iterations: int = Field(default=10, ge=1, le=99)
    agent_strategy: Literal["auto", "function-calling", "chain-of-thought"] = "auto"

    @property
    def tool_call_enabled(self) -> bool:
        return bool(self.tools) and any(t.enabled for t in self.tools)
api/core/workflow/nodes/agent_v2/event_adapter.py (new file)
@ -0,0 +1,96 @@
"""Event adapter for Agent V2 Node.

Converts AgentPattern outputs (LLMResultChunk | AgentLog) into
graphon NodeEventBase events consumable by the workflow engine.
"""

from __future__ import annotations

from collections.abc import Generator
from typing import Any

from graphon.model_runtime.entities import LLMResultChunk
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import (
    AgentLogEvent,
    ModelInvokeCompletedEvent,
    NodeEventBase,
    StreamChunkEvent,
)

from core.agent.entities import AgentLog, AgentResult


class AgentV2EventAdapter:
    """Converts agent strategy outputs into workflow node events."""

    def process_strategy_outputs(
        self,
        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
        *,
        node_id: str,
        node_execution_id: str,
    ) -> Generator[NodeEventBase, None, AgentResult]:
        """Process strategy generator outputs, yielding node events.

        Returns the final AgentResult from the strategy.
        """
        try:
            while True:
                item = next(outputs)
                if isinstance(item, AgentLog):
                    yield self._convert_agent_log(item, node_id=node_id, node_execution_id=node_execution_id)
                elif isinstance(item, LLMResultChunk):
                    yield from self._convert_llm_chunk(item, node_id=node_id)
        except StopIteration as e:
            result: AgentResult = e.value
            if result.usage:
                usage = result.usage if isinstance(result.usage, LLMUsage) else LLMUsage.empty_usage()
                yield ModelInvokeCompletedEvent(
                    text=result.text,
                    usage=usage,
                    finish_reason=result.finish_reason,
                )
            return result

    def _convert_agent_log(
        self,
        log: AgentLog,
        *,
        node_id: str,
        node_execution_id: str,
    ) -> AgentLogEvent:
        return AgentLogEvent(
            message_id=log.id,
            label=log.label,
            node_execution_id=node_execution_id,
            parent_id=log.parent_id,
            error=log.error,
            status=log.status.value,
            data=dict(log.data),
            metadata={k.value if hasattr(k, "value") else str(k): v for k, v in log.metadata.items()},
            node_id=node_id,
        )

    def _convert_llm_chunk(
        self,
        chunk: LLMResultChunk,
        *,
        node_id: str,
    ) -> Generator[NodeEventBase, None, None]:
        content = ""
        if chunk.delta.message and chunk.delta.message.content:
            if isinstance(chunk.delta.message.content, str):
                content = chunk.delta.message.content
            elif isinstance(chunk.delta.message.content, list):
                from graphon.model_runtime.entities.message_entities import TextPromptMessageContent

                for item in chunk.delta.message.content:
                    if isinstance(item, TextPromptMessageContent):
                        content += item.data

        if content:
            yield StreamChunkEvent(
                selector=[node_id, "text"],
                chunk=content,
            )
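process_strategy_outputs leans on a Python detail worth spelling out: a generator's return value travels inside StopIteration.value, so the adapter can re-yield streamed items and still hand the final AgentResult back to its own caller. A minimal self-contained sketch of that pattern (plain strings and dicts stand in for events and results):

    from collections.abc import Generator


    def strategy() -> Generator[str, None, dict]:
        # Yields streamed items, then *returns* a final result.
        yield "chunk-1"
        yield "chunk-2"
        return {"text": "done"}


    def adapter() -> Generator[str, None, dict]:
        result = yield from strategy()  # re-yields items, captures the return value
        return result


    gen = adapter()
    chunks = []
    try:
        while True:
            chunks.append(next(gen))
    except StopIteration as e:
        final = e.value  # the strategy's return value surfaces here

    assert chunks == ["chunk-1", "chunk-2"]
    assert final == {"text": "done"}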
api/core/workflow/nodes/agent_v2/node.py (new file)
@ -0,0 +1,370 @@
"""Agent V2 Workflow Node.

A unified workflow node that combines LLM capabilities with agent tool-calling.
When tools are configured, runs an FC/ReAct loop via StrategyFactory.
When no tools are present, behaves as a single-shot LLM invocation.
"""

from __future__ import annotations

import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast

from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessageContentUnionTypes,
)
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.node_events import (
    ModelInvokeCompletedEvent,
    NodeEventBase,
    NodeRunResult,
    StreamChunkEvent,
    StreamCompletedEvent,
)
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser

from core.agent.entities import AgentEntity, ExecutionContext
from core.agent.patterns import StrategyFactory
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.model_manager import ModelInstance, ModelManager
from core.workflow.system_variables import SystemVariableKey, get_system_text

from .entities import AGENT_V2_NODE_TYPE, AgentV2NodeData
from .event_adapter import AgentV2EventAdapter
from .tool_manager import AgentV2ToolManager

if TYPE_CHECKING:
    from graphon.entities import GraphInitParams
    from graphon.entities.graph_config import NodeConfigDict
    from graphon.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)

_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)


class AgentV2Node(Node[AgentV2NodeData]):
    node_type = AGENT_V2_NODE_TYPE

    _tool_manager: AgentV2ToolManager
    _event_adapter: AgentV2EventAdapter

    def __init__(
        self,
        id: str,
        config: NodeConfigDict,
        graph_init_params: GraphInitParams,
        graph_runtime_state: GraphRuntimeState,
        *,
        tool_manager: AgentV2ToolManager,
        event_adapter: AgentV2EventAdapter,
    ) -> None:
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        self._tool_manager = tool_manager
        self._event_adapter = event_adapter

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator[NodeEventBase, None, None]:
        dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))

        try:
            model_instance = self._fetch_model_instance(dify_ctx)
        except Exception as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error=f"Failed to load model: {e}",
                )
            )
            return

        prompt_messages = self._build_prompt_messages(dify_ctx)

        if self.node_data.tool_call_enabled:
            yield from self._run_with_tools(model_instance, prompt_messages, dify_ctx)
        else:
            yield from self._run_without_tools(model_instance, prompt_messages, dify_ctx)

    # ------------------------------------------------------------------
    # No-tools path: single LLM invocation (LLM Node equivalent)
    # ------------------------------------------------------------------

    def _run_without_tools(
        self,
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        dify_ctx: DifyRunContext,
    ) -> Generator[NodeEventBase, None, None]:
        try:
            result_chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=self.node_data.model.completion_params,
                tools=[],
                stop=[],
                stream=True,
                user=dify_ctx.user_id,
                callbacks=[],
            )

            full_text = ""
            reasoning_content = ""
            usage: LLMUsage | None = None
            finish_reason: str | None = None

            for chunk in result_chunks:
                chunk_text = self._extract_chunk_text(chunk)
                if chunk_text:
                    full_text += chunk_text
                    yield StreamChunkEvent(
                        selector=[self._node_id, "text"],
                        chunk=chunk_text,
                    )

                if chunk.delta.usage:
                    usage = chunk.delta.usage
                if chunk.delta.finish_reason:
                    finish_reason = chunk.delta.finish_reason

            if self.node_data.reasoning_format == "separated":
                full_text, reasoning_content = self._separate_reasoning(full_text)

            if usage:
                yield ModelInvokeCompletedEvent(
                    text=full_text,
                    usage=usage,
                    finish_reason=finish_reason,
                    reasoning_content=reasoning_content or None,
                )

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]},
                    outputs={
                        "text": full_text,
                        "reasoning_content": reasoning_content,
                        "usage": usage.model_dump() if usage else {},
                        "finish_reason": finish_reason or "stop",
                    },
                )
            )
        except Exception as e:
            logger.exception("Agent V2 LLM invocation failed")
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error=str(e),
                )
            )

    # ------------------------------------------------------------------
    # Tools path: agent loop via StrategyFactory
    # ------------------------------------------------------------------

    def _run_with_tools(
        self,
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        dify_ctx: DifyRunContext,
    ) -> Generator[NodeEventBase, None, None]:
        try:
            tool_instances = self._tool_manager.prepare_tool_instances(
                list(self.node_data.tools),
            )

            model_features = self._get_model_features(model_instance)

            context = ExecutionContext(
                user_id=dify_ctx.user_id,
                app_id=dify_ctx.app_id,
                tenant_id=dify_ctx.tenant_id,
                conversation_id=get_system_text(
                    self.graph_runtime_state.variable_pool,
                    SystemVariableKey.CONVERSATION_ID,
                ),
            )

            agent_strategy_enum = self._map_strategy_config(self.node_data.agent_strategy)

            strategy = StrategyFactory.create_strategy(
                model_features=model_features,
                model_instance=model_instance,
                tools=tool_instances,
                files=[],
                max_iterations=self.node_data.max_iterations,
                context=context,
                agent_strategy=agent_strategy_enum,
                tool_invoke_hook=self._tool_manager.create_workflow_tool_invoke_hook(context),
            )

            outputs_gen = strategy.run(
                prompt_messages=prompt_messages,
                model_parameters=self.node_data.model.completion_params,
                stop=[],
                stream=True,
            )

            result = yield from self._event_adapter.process_strategy_outputs(
                outputs_gen,
                node_id=self._node_id,
                node_execution_id=self.id,
            )

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    inputs={"prompt_messages": [m.model_dump() for m in prompt_messages]},
                    outputs={
                        "text": result.text,
                        "files": [f.model_dump() if hasattr(f, "model_dump") else str(f) for f in result.files],
                        "usage": result.usage.model_dump() if hasattr(result.usage, "model_dump") else {},
                        "finish_reason": result.finish_reason or "stop",
                    },
                )
            )
        except Exception as e:
            logger.exception("Agent V2 tool execution failed")
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error=str(e),
                )
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _fetch_model_instance(self, dify_ctx: DifyRunContext) -> ModelInstance:
        model_config = self.node_data.model
        model_instance = ModelManager().get_model_instance(
            tenant_id=dify_ctx.tenant_id,
            provider=model_config.provider,
            model_type=ModelType.LLM,
            model=model_config.name,
        )
        return model_instance

    def _build_prompt_messages(self, dify_ctx: DifyRunContext) -> list[PromptMessage]:
        """Build prompt messages from the node's prompt_template, resolving variables."""
        variable_pool = self.graph_runtime_state.variable_pool
        messages: list[PromptMessage] = []

        template = self.node_data.prompt_template
        if isinstance(template, Sequence) and not isinstance(template, str):
            for msg_template in template:
                role = msg_template.role.value if hasattr(msg_template.role, "value") else str(msg_template.role)
                text = msg_template.text or ""
                jinja2_text = getattr(msg_template, "jinja2_text", None)
                content = jinja2_text or text

                resolved = VariableTemplateParser.resolve_template(content, variable_pool)

                if role == "system":
                    messages.append(SystemPromptMessage(content=resolved))
                elif role == "user":
                    messages.append(UserPromptMessage(content=resolved))
                elif role == "assistant":
                    messages.append(AssistantPromptMessage(content=resolved))
        else:
            text_content = getattr(template, "text", "") or ""
            resolved = VariableTemplateParser.resolve_template(text_content, variable_pool)
            messages.append(UserPromptMessage(content=resolved))

        return messages

    def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
        try:
            model_schema = model_instance.model_type_instance.get_model_schema(
                model_instance.model_name,
                model_instance.credentials,
            )
            return list(model_schema.features) if model_schema and model_schema.features else []
        except Exception:
            logger.warning("Failed to get model features, assuming none")
            return []

    @staticmethod
    def _map_strategy_config(
        config_value: Literal["auto", "function-calling", "chain-of-thought"],
    ) -> AgentEntity.Strategy | None:
        mapping = {
            "function-calling": AgentEntity.Strategy.FUNCTION_CALLING,
            "chain-of-thought": AgentEntity.Strategy.CHAIN_OF_THOUGHT,
        }
        return mapping.get(config_value)

    @staticmethod
    def _extract_chunk_text(chunk: LLMResultChunk) -> str:
        if not chunk.delta.message or not chunk.delta.message.content:
            return ""
        content = chunk.delta.message.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, TextPromptMessageContent):
                    parts.append(item.data)
            return "".join(parts)
        return ""

    @staticmethod
    def _separate_reasoning(text: str) -> tuple[str, str]:
        """Extract <think> blocks from text, return (clean_text, reasoning_content)."""
        reasoning_parts = _THINK_PATTERN.findall(text)
        reasoning_content = "\n".join(reasoning_parts)
        clean_text = _THINK_PATTERN.sub("", text).strip()
        return clean_text, reasoning_content

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: AgentV2NodeData,
    ) -> Mapping[str, Sequence[str]]:
        result: dict[str, list[str]] = {}

        if isinstance(node_data.prompt_template, Sequence) and not isinstance(node_data.prompt_template, str):
            for msg in node_data.prompt_template:
                text = msg.text or ""
                jinja2_text = getattr(msg, "jinja2_text", None)
                content = jinja2_text or text
                selectors = VariableTemplateParser(content).extract_variable_selectors()
                for selector in selectors:
                    result[selector.variable] = selector.value_selector
        else:
            text_content = getattr(node_data.prompt_template, "text", "") or ""
            selectors = VariableTemplateParser(text_content).extract_variable_selectors()
            for selector in selectors:
                result[selector.variable] = selector.value_selector

        return {f"{node_id}.{key}": value for key, value in result.items()}
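The reasoning_format="separated" path above hinges on the _THINK_PATTERN regex. Rerun standalone, the rule behaves like this (same regex as the file, wrapper function illustrative):

    import re

    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)


    def separate_reasoning(text: str) -> tuple[str, str]:
        # Collect every <think>...</think> body, then strip the tags from the answer.
        reasoning = "\n".join(_THINK_PATTERN.findall(text))
        clean = _THINK_PATTERN.sub("", text).strip()
        return clean, reasoning


    clean, reasoning = separate_reasoning("<think>check the docs</think>The answer is 42.")
    assert clean == "The answer is 42."
    assert reasoning == "check the docs"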
api/core/workflow/nodes/agent_v2/tool_manager.py (new file)
@ -0,0 +1,122 @@
"""Tool management for Agent V2 Node.

Handles tool instance preparation, conversion to LLM-consumable format,
and creation of workflow-compatible tool invoke hooks.
"""

from __future__ import annotations

import json
import logging
from collections.abc import Generator
from typing import TYPE_CHECKING, Any

from graphon.file import File
from graphon.model_runtime.entities import PromptMessageTool

from core.agent.entities import AgentToolEntity, ExecutionContext
from core.agent.patterns.base import ToolInvokeHook
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta, ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager

if TYPE_CHECKING:
    from .entities import ToolMetadata

logger = logging.getLogger(__name__)


class AgentV2ToolManager:
    """Manages tool lifecycle for Agent V2 node execution."""

    def __init__(
        self,
        *,
        tenant_id: str,
        app_id: str,
    ) -> None:
        self._tenant_id = tenant_id
        self._app_id = app_id

    def prepare_tool_instances(
        self,
        tools_config: list[ToolMetadata],
    ) -> list[Tool]:
        """Convert tool metadata configs into runtime Tool instances."""
        tool_instances: list[Tool] = []
        for tool_meta in tools_config:
            if not tool_meta.enabled:
                continue
            try:
                # Unwrap settings stored as {"value": {"type": ..., "value": ...}} envelopes;
                # anything else passes through unchanged.
                processed_settings = {}
                for key, value in tool_meta.settings.items():
                    if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
                        if "type" in value["value"] and "value" in value["value"]:
                            processed_settings[key] = value["value"]
                        else:
                            processed_settings[key] = value
                    else:
                        processed_settings[key] = value

                merged_parameters = {**tool_meta.parameters, **processed_settings}

                agent_tool = AgentToolEntity(
                    provider_id=tool_meta.provider_name,
                    provider_type=tool_meta.type,
                    tool_name=tool_meta.tool_name,
                    tool_parameters=merged_parameters,
                    plugin_unique_identifier=tool_meta.plugin_unique_identifier,
                    credential_id=tool_meta.credential_id,
                )

                tool_runtime = ToolManager.get_agent_tool_runtime(
                    tenant_id=self._tenant_id,
                    app_id=self._app_id,
                    agent_tool=agent_tool,
                )
                tool_instances.append(tool_runtime)
            except Exception:
                logger.warning(
                    "Failed to prepare tool %s/%s, skipping",
                    tool_meta.provider_name,
                    tool_meta.tool_name,
                    exc_info=True,
                )
                continue

        return tool_instances

    def create_workflow_tool_invoke_hook(
        self,
        context: ExecutionContext,
        workflow_call_depth: int = 0,
    ) -> ToolInvokeHook:
        """Create a ToolInvokeHook for workflow context (uses generic_invoke)."""

        def hook(
            tool: Tool,
            tool_args: dict[str, Any],
            tool_name: str,
        ) -> tuple[str, list[str], ToolInvokeMeta]:
            tool_response = ToolEngine.generic_invoke(
                tool=tool,
                tool_parameters=tool_args,
                user_id=context.user_id or "",
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=workflow_call_depth,
                app_id=context.app_id,
                conversation_id=context.conversation_id,
            )

            response_content = ""
            for response in tool_response:
                if response.type == ToolInvokeMessage.MessageType.TEXT:
                    assert isinstance(response.message, ToolInvokeMessage.TextMessage)
                    response_content += response.message.text
                elif response.type == ToolInvokeMessage.MessageType.JSON:
                    if isinstance(response.message, ToolInvokeMessage.JsonMessage):
                        response_content += json.dumps(response.message.json_object, ensure_ascii=False)
                elif response.type == ToolInvokeMessage.MessageType.LINK:
                    if isinstance(response.message, ToolInvokeMessage.TextMessage):
                        response_content += f"[Link: {response.message.text}]"

            return response_content, [], ToolInvokeMeta.empty()

        return hook
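One design point in the hunk above: the returned hook closes over the ExecutionContext, so the agent loop never sees workflow plumbing. The contract in isolation, with Dify types replaced by placeholders (everything here is illustrative):

    from typing import Any, Callable

    # Illustrative hook contract: (tool, args, name) -> (text, file_ids, invoke_meta).
    InvokeHook = Callable[[object, dict[str, Any], str], tuple[str, list[str], None]]


    def make_hook(user_id: str) -> InvokeHook:
        def hook(tool: object, tool_args: dict[str, Any], tool_name: str) -> tuple[str, list[str], None]:
            # A real hook would call the tool engine here; we only echo for illustration.
            return f"{tool_name} invoked by {user_id} with {tool_args}", [], None

        return hook


    hook = make_hook("user-1")
    text, files, meta = hook(object(), {"q": "dify"}, "search")
    assert "search invoked by user-1" in text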
@ -352,6 +352,7 @@ class AppMode(StrEnum):
    CHAT = "chat"
    ADVANCED_CHAT = "advanced-chat"
    AGENT_CHAT = "agent-chat"
    AGENT = "agent"
    CHANNEL = "channel"
    RAG_PIPELINE = "rag-pipeline"
@ -483,7 +483,7 @@ class AppDslService:
        )

        # Initialize app based on mode
        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
            workflow_data = data.get("workflow")
            if not workflow_data or not isinstance(workflow_data, dict):
                raise ValueError("Missing workflow data for workflow/advanced chat app")

@ -566,7 +566,7 @@ class AppDslService:
            },
        }

        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
            cls._append_workflow_export_data(
                export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
            )
@ -147,6 +147,54 @@ class AppGenerateService:
                        ),
                        request_id=request_id,
                    )
            case AppMode.AGENT:
                workflow_id = args.get("workflow_id")
                workflow = cls._get_workflow(app_model, invoke_from, workflow_id)

                if streaming:
                    with rate_limit_context(rate_limit, request_id):
                        payload = AppExecutionParams.new(
                            app_model=app_model,
                            workflow=workflow,
                            user=user,
                            args=args,
                            invoke_from=invoke_from,
                            streaming=True,
                            call_depth=0,
                        )
                        payload_json = payload.model_dump_json()

                        def on_subscribe():
                            workflow_based_app_execution_task.delay(payload_json)

                        on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
                        generator = AdvancedChatAppGenerator()
                        return rate_limit.generate(
                            generator.convert_to_event_stream(
                                generator.retrieve_events(
                                    AppMode.ADVANCED_CHAT,
                                    payload.workflow_run_id,
                                    on_subscribe=on_subscribe,
                                ),
                            ),
                            request_id=request_id,
                        )
                else:
                    advanced_generator = AdvancedChatAppGenerator()
                    return rate_limit.generate(
                        advanced_generator.convert_to_event_stream(
                            advanced_generator.generate(
                                app_model=app_model,
                                workflow=workflow,
                                user=user,
                                args=args,
                                invoke_from=invoke_from,
                                workflow_run_id=str(uuid.uuid4()),
                                streaming=False,
                            )
                        ),
                        request_id=request_id,
                    )
            case AppMode.ADVANCED_CHAT:
                workflow_id = args.get("workflow_id")
                workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
@ -10,6 +10,7 @@ from sqlalchemy import select

from configs import dify_config
from constants.model_template import default_app_templates
from services.workflow.graph_factory import WorkflowGraphFactory
from core.agent.entities import AgentToolEntity
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager

@ -169,6 +170,9 @@ class AppService:

        db.session.commit()

        if app_mode == AppMode.AGENT:
            self._init_agent_workflow(app, account, default_model_dict if default_model_config else None)

        app_was_created.send(app, account=account)

        if FeatureService.get_system_features().webapp_auth.enabled:

@ -180,6 +184,34 @@ class AppService:

        return app

    @staticmethod
    def _init_agent_workflow(app: App, account: Any, model_dict: dict | None) -> None:
        """Create the default single-agent-node workflow for a new Agent app."""
        from services.workflow_service import WorkflowService

        model_config = model_dict or {
            "provider": "openai",
            "name": "gpt-4o",
            "mode": "chat",
            "completion_params": {},
        }

        graph = WorkflowGraphFactory.create_single_agent_graph(
            model_config=model_config,
            is_chat=True,
        )

        workflow_service = WorkflowService()
        workflow_service.sync_draft_workflow(
            app_model=app,
            graph=graph,
            features={},
            unique_hash=None,
            account=account,
            environment_variables=[],
            conversation_variables=[],
        )

    def get_app(self, app: App) -> App:
        """
        Get App
api/services/workflow/graph_factory.py (new file)
@ -0,0 +1,113 @@
"""Factory for programmatically building workflow graphs.

Used by AppService to auto-generate single-node workflow graphs when
creating a new Agent app (AppMode.AGENT).
"""

from typing import Any

from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE


class WorkflowGraphFactory:
    """Builds workflow graph dicts for special app creation flows."""

    @staticmethod
    def create_single_agent_graph(
        model_config: dict[str, Any],
        is_chat: bool = True,
    ) -> dict[str, Any]:
        """Create a minimal start -> agent_v2 -> answer/end graph.

        Args:
            model_config: Model configuration dict with provider, name, mode, completion_params.
            is_chat: If True, creates chatflow (with answer node); otherwise workflow (with end node).

        Returns:
            Graph dict with nodes and edges, ready for WorkflowService.sync_draft_workflow().
        """
        agent_node_data: dict[str, Any] = {
            "type": AGENT_V2_NODE_TYPE,
            "title": "Agent",
            "model": model_config,
            "prompt_template": [
                {"role": "system", "text": "You are a helpful assistant."},
                {"role": "user", "text": "{{#sys.query#}}"},
            ],
            "tools": [],
            "max_iterations": 10,
            "agent_strategy": "auto",
            "context": {"enabled": False},
            "vision": {"enabled": False},
        }

        if is_chat:
            agent_node_data["memory"] = {"window": {"enabled": True, "size": 50}}

        nodes: list[dict[str, Any]] = [
            {
                "id": "start",
                "type": "custom",
                "data": {"type": "start", "title": "Start", "variables": []},
                "position": {"x": 80, "y": 282},
            },
            {
                "id": "agent",
                "type": "custom",
                "data": agent_node_data,
                "position": {"x": 400, "y": 282},
            },
        ]

        if is_chat:
            nodes.append(
                {
                    "id": "answer",
                    "type": "custom",
                    "data": {
                        "type": "answer",
                        "title": "Answer",
                        "answer": "{{#agent.text#}}",
                    },
                    "position": {"x": 720, "y": 282},
                }
            )
            end_node_id = "answer"
        else:
            nodes.append(
                {
                    "id": "end",
                    "type": "custom",
                    "data": {
                        "type": "end",
                        "title": "End",
                        "outputs": [
                            {
                                "value_selector": ["agent", "text"],
                                "variable": "result",
                            }
                        ],
                    },
                    "position": {"x": 720, "y": 282},
                }
            )
            end_node_id = "end"

        edges: list[dict[str, str]] = [
            {
                "id": "start-agent",
                "source": "start",
                "target": "agent",
                "sourceHandle": "source",
                "targetHandle": "target",
            },
            {
                "id": f"agent-{end_node_id}",
                "source": "agent",
                "target": end_node_id,
                "sourceHandle": "source",
                "targetHandle": "target",
            },
        ]

        return {"nodes": nodes, "edges": edges}
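A quick smoke run of the factory above, assuming the api package is importable; the asserted shapes follow directly from the code in this hunk:

    from services.workflow.graph_factory import WorkflowGraphFactory

    graph = WorkflowGraphFactory.create_single_agent_graph(
        {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
        is_chat=True,
    )
    # Chatflow variant: start -> agent -> answer, two edges.
    assert [n["id"] for n in graph["nodes"]] == ["start", "agent", "answer"]
    assert [(e["source"], e["target"]) for e in graph["edges"]] == [("start", "agent"), ("agent", "answer")]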
@ -0,0 +1,332 @@
"""Basic tests for Agent V2 node — Phase 1 + 2 validation.

Tests:
1. Module imports resolve without errors
2. AgentV2Node self-registers in the graphon Node registry
3. DifyNodeFactory kwargs mapping includes agent-v2
4. StrategyFactory selects correct strategy based on model features
5. AgentV2NodeData validates with and without tools
"""

import pytest


class TestPhase1Imports:
    """Verify Phase 1 (Agent Patterns) modules import correctly."""

    def test_entities_import(self):
        from core.agent.entities import AgentLog, AgentResult, ExecutionContext

        assert ExecutionContext is not None
        assert AgentLog is not None
        assert AgentResult is not None

    def test_entities_backward_compatible(self):
        from core.agent.entities import (
            AgentEntity,
            AgentInvokeMessage,
            AgentPromptEntity,
            AgentScratchpadUnit,
            AgentToolEntity,
        )

        assert AgentEntity is not None
        assert AgentToolEntity is not None
        assert AgentPromptEntity is not None
        assert AgentScratchpadUnit is not None
        assert AgentInvokeMessage is not None

    def test_patterns_module_import(self):
        from core.agent.patterns import (
            AgentPattern,
            FunctionCallStrategy,
            ReActStrategy,
            StrategyFactory,
        )

        assert AgentPattern is not None
        assert FunctionCallStrategy is not None
        assert ReActStrategy is not None
        assert StrategyFactory is not None

    def test_patterns_inheritance(self):
        from core.agent.patterns import AgentPattern, FunctionCallStrategy, ReActStrategy

        assert issubclass(FunctionCallStrategy, AgentPattern)
        assert issubclass(ReActStrategy, AgentPattern)


class TestPhase2Imports:
    """Verify Phase 2 (Agent V2 Node) modules import correctly."""

    def test_entities_import(self):
        from core.workflow.nodes.agent_v2.entities import (
            AGENT_V2_NODE_TYPE,
            AgentV2NodeData,
            ContextConfig,
            ToolMetadata,
            VisionConfig,
        )

        assert AGENT_V2_NODE_TYPE == "agent-v2"
        assert AgentV2NodeData is not None
        assert ToolMetadata is not None

    def test_node_import(self):
        from core.workflow.nodes.agent_v2.node import AgentV2Node

        assert AgentV2Node is not None
        assert AgentV2Node.node_type == "agent-v2"

    def test_tool_manager_import(self):
        from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager

        assert AgentV2ToolManager is not None

    def test_event_adapter_import(self):
        from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter

        assert AgentV2EventAdapter is not None


class TestNodeRegistration:
    """Verify AgentV2Node self-registers in the graphon Node registry."""

    def test_agent_v2_in_registry(self):
        from core.workflow.node_factory import register_nodes

        register_nodes()

        from graphon.nodes.base.node import Node

        registry = Node.get_node_type_classes_mapping()
        assert "agent-v2" in registry, f"agent-v2 not found in registry. Available: {list(registry.keys())}"

    def test_agent_v2_latest_version(self):
        from core.workflow.node_factory import register_nodes

        register_nodes()

        from graphon.nodes.base.node import Node

        registry = Node.get_node_type_classes_mapping()
        agent_v2_versions = registry.get("agent-v2", {})
        assert "latest" in agent_v2_versions
        assert "1" in agent_v2_versions

        from core.workflow.nodes.agent_v2.node import AgentV2Node

        assert agent_v2_versions["latest"] is AgentV2Node
        assert agent_v2_versions["1"] is AgentV2Node

    def test_old_agent_still_registered(self):
        """Old Agent node must not be affected by Agent V2."""
        from core.workflow.node_factory import register_nodes

        register_nodes()

        from graphon.nodes.base.node import Node

        registry = Node.get_node_type_classes_mapping()
        assert "agent" in registry, "Old agent node must still be registered"

    def test_resolve_workflow_node_class(self):
        from core.workflow.node_factory import register_nodes, resolve_workflow_node_class
        from core.workflow.nodes.agent_v2.node import AgentV2Node

        register_nodes()

        resolved = resolve_workflow_node_class(node_type="agent-v2", node_version="1")
        assert resolved is AgentV2Node

        resolved_latest = resolve_workflow_node_class(node_type="agent-v2", node_version="latest")
        assert resolved_latest is AgentV2Node


class TestNodeFactoryKwargs:
    """Verify DifyNodeFactory includes agent-v2 in kwargs mapping."""

    def test_agent_v2_node_type_in_factory(self):
        from core.workflow.node_factory import AGENT_V2_NODE_TYPE

        assert AGENT_V2_NODE_TYPE == "agent-v2"


class TestStrategyFactory:
    """Verify StrategyFactory selects correct strategy."""

    def test_fc_selected_for_tool_call_model(self):
        from graphon.model_runtime.entities.model_entities import ModelFeature

        from core.agent.patterns import FunctionCallStrategy, StrategyFactory

        assert ModelFeature.TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES
        assert ModelFeature.MULTI_TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES

    def test_factory_has_create_strategy(self):
        from core.agent.patterns import StrategyFactory

        assert callable(getattr(StrategyFactory, "create_strategy", None))


class TestAgentV2NodeData:
    """Verify AgentV2NodeData validation."""

    def test_minimal_data(self):
        from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

        data = AgentV2NodeData(
            title="Test Agent",
            model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            prompt_template=[{"role": "system", "text": "You are helpful."}, {"role": "user", "text": "Hello"}],
            context={"enabled": False},
        )
        assert data.type == "agent-v2"
        assert data.tool_call_enabled is False
        assert data.max_iterations == 10
        assert data.agent_strategy == "auto"

    def test_data_with_tools(self):
        from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

        data = AgentV2NodeData(
            title="Test Agent with Tools",
            model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            prompt_template=[{"role": "user", "text": "Search for {{query}}"}],
            context={"enabled": False},
            tools=[
                {
                    "enabled": True,
                    "type": "builtin",
                    "provider_name": "google",
                    "tool_name": "google_search",
                }
            ],
            max_iterations=5,
            agent_strategy="function-calling",
        )
        assert data.tool_call_enabled is True
        assert data.max_iterations == 5
        assert data.agent_strategy == "function-calling"
        assert len(data.tools) == 1

    def test_data_with_disabled_tools(self):
        from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

        data = AgentV2NodeData(
            title="Test Agent",
            model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            prompt_template=[{"role": "user", "text": "Hello"}],
            context={"enabled": False},
            tools=[
                {
                    "enabled": False,
                    "type": "builtin",
                    "provider_name": "google",
                    "tool_name": "google_search",
                }
            ],
        )
        assert data.tool_call_enabled is False

    def test_data_with_memory(self):
        from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

        data = AgentV2NodeData(
            title="Test Agent",
            model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            prompt_template=[{"role": "user", "text": "Hello"}],
            context={"enabled": False},
            memory={"window": {"enabled": True, "size": 50}},
        )
        assert data.memory is not None
        assert data.memory.window.enabled is True
        assert data.memory.window.size == 50

    def test_data_with_vision(self):
        from core.workflow.nodes.agent_v2.entities import AgentV2NodeData

        data = AgentV2NodeData(
            title="Test Agent",
            model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            prompt_template=[{"role": "user", "text": "Hello"}],
            context={"enabled": False},
            vision={"enabled": True},
        )
        assert data.vision.enabled is True


class TestExecutionContext:
    """Verify ExecutionContext entity."""

    def test_create_minimal(self):
        from core.agent.entities import ExecutionContext

        ctx = ExecutionContext.create_minimal(user_id="user-123")
        assert ctx.user_id == "user-123"
        assert ctx.app_id is None

    def test_to_dict(self):
        from core.agent.entities import ExecutionContext

        ctx = ExecutionContext(user_id="u1", app_id="a1", tenant_id="t1")
        d = ctx.to_dict()
        assert d["user_id"] == "u1"
        assert d["app_id"] == "a1"
        assert d["tenant_id"] == "t1"
        assert d["conversation_id"] is None

    def test_with_updates(self):
        from core.agent.entities import ExecutionContext

        ctx = ExecutionContext(user_id="u1")
        ctx2 = ctx.with_updates(app_id="a1", conversation_id="c1")
        assert ctx2.user_id == "u1"
        assert ctx2.app_id == "a1"
        assert ctx2.conversation_id == "c1"


class TestAgentLog:
    """Verify AgentLog entity."""

    def test_create_log(self):
        from core.agent.entities import AgentLog

        log = AgentLog(
            label="Round 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )
        assert log.id is not None
        assert log.label == "Round 1"
        assert log.log_type == "round"
        assert log.status == "start"
        assert log.parent_id is None

    def test_log_types(self):
        from core.agent.entities import AgentLog

        assert AgentLog.LogType.ROUND == "round"
        assert AgentLog.LogType.THOUGHT == "thought"
        assert AgentLog.LogType.TOOL_CALL == "tool_call"


class TestAgentResult:
    """Verify AgentResult entity."""

    def test_default_result(self):
        from core.agent.entities import AgentResult

        result = AgentResult()
        assert result.text == ""
        assert result.files == []
        assert result.usage is None
        assert result.finish_reason is None

    def test_result_with_data(self):
        from core.agent.entities import AgentResult

        result = AgentResult(text="Hello world", finish_reason="stop")
        assert result.text == "Hello world"
        assert result.finish_reason == "stop"
@ -0,0 +1,132 @@
"""Tests for Phase 3 — Agent app type support."""

import pytest


class TestAppModeAgent:
    """Verify AppMode.AGENT is properly defined."""

    def test_agent_mode_exists(self):
        from models.model import AppMode

        assert hasattr(AppMode, "AGENT")
        assert AppMode.AGENT == "agent"

    def test_agent_mode_value_of(self):
        from models.model import AppMode

        mode = AppMode.value_of("agent")
        assert mode == AppMode.AGENT

    def test_all_original_modes_still_work(self):
        from models.model import AppMode

        for val in ["completion", "workflow", "chat", "advanced-chat", "agent-chat", "channel", "rag-pipeline"]:
            mode = AppMode.value_of(val)
            assert mode.value == val


class TestDefaultAppTemplate:
    """Verify AGENT template is defined."""

    def test_agent_template_exists(self):
        from constants.model_template import default_app_templates
        from models.model import AppMode

        assert AppMode.AGENT in default_app_templates
        template = default_app_templates[AppMode.AGENT]
        assert template["app"]["mode"] == AppMode.AGENT
        assert template["app"]["enable_site"] is True
        assert "model_config" in template

    def test_all_original_templates_exist(self):
        from constants.model_template import default_app_templates
        from models.model import AppMode

        for mode in [AppMode.WORKFLOW, AppMode.COMPLETION, AppMode.CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT]:
            assert mode in default_app_templates


class TestWorkflowGraphFactory:
    """Verify WorkflowGraphFactory creates valid graphs."""

    def test_create_chat_graph(self):
        from services.workflow.graph_factory import WorkflowGraphFactory

        model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
        graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=True)

        assert "nodes" in graph
        assert "edges" in graph
        assert len(graph["nodes"]) == 3
        assert len(graph["edges"]) == 2

        node_types = [n["data"]["type"] for n in graph["nodes"]]
        assert "start" in node_types
        assert "agent-v2" in node_types
        assert "answer" in node_types

        agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2")
        assert agent_node["data"]["model"] == model_config
        assert agent_node["data"]["memory"] is not None

    def test_create_workflow_graph(self):
        from services.workflow.graph_factory import WorkflowGraphFactory

        model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
        graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=False)

        node_types = [n["data"]["type"] for n in graph["nodes"]]
        assert "end" in node_types
        assert "answer" not in node_types

        agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2")
        assert agent_node["data"].get("memory") is None

    def test_edge_connectivity(self):
        from services.workflow.graph_factory import WorkflowGraphFactory

        graph = WorkflowGraphFactory.create_single_agent_graph(
            {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
            is_chat=True,
        )

        edges = graph["edges"]
        sources = {e["source"] for e in edges}
        targets = {e["target"] for e in edges}
        assert "start" in sources
        assert "agent" in sources
        assert "agent" in targets
        assert "answer" in targets


class TestConsoleAppController:
    """Verify Console API allows 'agent' mode."""

    def test_allow_create_app_modes(self):
        from controllers.console.app.app import ALLOW_CREATE_APP_MODES

        assert "agent" in ALLOW_CREATE_APP_MODES
        assert "chat" in ALLOW_CREATE_APP_MODES
        assert "agent-chat" in ALLOW_CREATE_APP_MODES


class TestAppGenerateServiceHasAgentCase:
    """Verify the generate() method has an AppMode.AGENT case."""

    def test_generate_method_exists(self):
        from services.app_generate_service import AppGenerateService

        assert hasattr(AppGenerateService, "generate")

    def test_agent_mode_import(self):
        """Verify AppMode.AGENT can be used in match statement context."""
        from models.model import AppMode

        mode = AppMode.AGENT
        match mode:
            case AppMode.AGENT:
                result = "agent"
            case _:
                result = "other"
        assert result == "agent"