diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f8213d9fd7..4d29f419d1 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1047,6 +1047,8 @@ class ToolManager:
                 continue
             tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
             if tool_input.type == "variable":
+                if not isinstance(tool_input.value, list):
+                    raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
                 variable = variable_pool.get(tool_input.value)
                 if variable is None:
                     raise ToolParameterError(f"Variable {tool_input.value} does not exist")
@@ -1056,6 +1058,11 @@ class ToolManager:
             elif tool_input.type == "mixed":
                 segment_group = variable_pool.convert_template(str(tool_input.value))
                 parameter_value = segment_group.text
+            elif tool_input.type == "mention":
+                # Mention type not supported in agent mode
+                raise ToolParameterError(
+                    f"Mention type not supported in agent mode for parameter '{parameter.name}'"
+                )
             else:
                 raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
             runtime_parameters[parameter.name] = parameter_value
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
index 7a1cbf9940..e00cfceedd 100644
--- a/api/core/variables/__init__.py
+++ b/api/core/variables/__init__.py
@@ -4,6 +4,7 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     FileSegment,
@@ -20,6 +21,7 @@ from .variables import (
     ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
+    ArrayPromptMessageVariable,
     ArrayStringVariable,
     ArrayVariable,
     FileVariable,
@@ -41,6 +43,8 @@ __all__ = [
     "ArrayNumberVariable",
    "ArrayObjectSegment",
    "ArrayObjectVariable",
+    "ArrayPromptMessageSegment",
+    "ArrayPromptMessageVariable",
    "ArraySegment",
    "ArrayStringSegment",
    "ArrayStringVariable",
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 406b4e6f93..61bd62628c 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -6,6 +6,7 @@ from typing import Annotated, Any, TypeAlias
 from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
 
 from core.file import File
+from core.model_runtime.entities import PromptMessage
 
 from .types import SegmentType
 
@@ -208,6 +209,15 @@ class ArrayBooleanSegment(ArraySegment):
     value: Sequence[bool]
 
 
+class ArrayPromptMessageSegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
+    value: Sequence[PromptMessage]
+
+    def to_object(self):
+        """Convert to JSON-serializable format for database storage and frontend."""
+        return [msg.model_dump() for msg in self.value]
+
+
 def get_segment_discriminator(v: Any) -> SegmentType | None:
     if isinstance(v, Segment):
         return v.value_type
@@ -248,6 +258,7 @@ SegmentUnion: TypeAlias = Annotated[
         | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
         | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
         | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+        | Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
     ),
     Discriminator(get_segment_discriminator),
 ]
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index 13b926c978..ac055ae232 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -45,6 +45,7 @@ class SegmentType(StrEnum):
     ARRAY_OBJECT = "array[object]"
     ARRAY_FILE = "array[file]"
     ARRAY_BOOLEAN = "array[boolean]"
+    ARRAY_PROMPT_MESSAGE = "array[message]"
     NONE = "none"
diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py
index 8e738f8fd5..799a923084 100644
--- a/api/core/variables/utils.py
+++ b/api/core/variables/utils.py
@@ -3,8 +3,10 @@ from typing import Any
 
 import orjson
 
+from core.model_runtime.entities import PromptMessage
+
 from .segment_group import SegmentGroup
-from .segments import ArrayFileSegment, FileSegment, Segment
+from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
 
 
 def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
@@ -16,7 +18,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
 
 def segment_orjson_default(o: Any):
     """Default function for orjson serialization of Segment types"""
-    if isinstance(o, ArrayFileSegment):
+    if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
         return [v.model_dump() for v in o.value]
     elif isinstance(o, FileSegment):
         return o.value.model_dump()
@@ -24,6 +26,8 @@ def segment_orjson_default(o: Any):
         return [segment_orjson_default(seg) for seg in o.value]
     elif isinstance(o, Segment):
         return o.value
+    elif isinstance(o, PromptMessage):
+        return o.model_dump()
     raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 9fd0bbc5b2..5ef13ad4f0 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -12,6 +12,7 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     BooleanSegment,
@@ -110,6 +111,10 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
     pass
 
 
+class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
+    pass
+
+
 class RAGPipelineVariable(BaseModel):
     belong_to_node_id: str = Field(description="belong to which node id, shared means public")
     type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@@ -160,6 +165,7 @@ VariableUnion: TypeAlias = Annotated[
         | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
         | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
         | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
+        | Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
         | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
     ),
     Discriminator(get_segment_discriminator),
diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py
index d38d1eba96..bd2326e84f 100644
--- a/api/core/workflow/graph/graph.py
+++ b/api/core/workflow/graph/graph.py
@@ -311,9 +311,9 @@ class Graph:
         #   - custom-note: top-level type (node_config.type == "custom-note")
         #   - group: data-level type (node_config.data.type == "group")
         node_configs = [
-            node_config for node_config in node_configs
-            if node_config.get("type", "") != "custom-note"
-            and node_config.get("data", {}).get("type", "") != "group"
+            node_config
+            for node_config in node_configs
+            if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
         ]
 
         # Parse node configurations
diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py
index 7a10b0f291..c90faf6e5e 100644
--- a/api/core/workflow/graph_engine/event_management/event_handlers.py
+++ b/api/core/workflow/graph_engine/event_management/event_handlers.py
@@ -125,9 +125,9 @@ class EventHandler:
         Args:
             event: The node started event
         """
-        # Check if this is a virtual node (extraction node)
-        if self._is_virtual_node(event.node_id):
-            self._handle_virtual_node_started(event)
+        # Check if this is an extractor node (has parent_node_id)
+        if self._is_extractor_node(event.node_id):
+            self._handle_extractor_node_started(event)
             return
 
         # Track execution in domain model
@@ -169,9 +169,9 @@ class EventHandler:
         Args:
             event: The node succeeded event
         """
-        # Check if this is a virtual node (extraction node)
-        if self._is_virtual_node(event.node_id):
-            self._handle_virtual_node_success(event)
+        # Check if this is an extractor node (has parent_node_id)
+        if self._is_extractor_node(event.node_id):
+            self._handle_extractor_node_success(event)
             return
 
         # Update domain model
@@ -236,9 +236,9 @@ class EventHandler:
         Args:
             event: The node failed event
         """
-        # Check if this is a virtual node (extraction node)
-        if self._is_virtual_node(event.node_id):
-            self._handle_virtual_node_failed(event)
+        # Check if this is an extractor node (has parent_node_id)
+        if self._is_extractor_node(event.node_id):
+            self._handle_extractor_node_failed(event)
             return
 
         # Update domain model
@@ -361,23 +361,23 @@ class EventHandler:
         else:
             self._graph_runtime_state.set_output(key, value)
 
-    def _is_virtual_node(self, node_id: str) -> bool:
+    def _is_extractor_node(self, node_id: str) -> bool:
         """
-        Check if node_id represents a virtual sub-node.
+        Check if node_id represents an extractor node (has parent_node_id).
 
-        Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
-        We check if the part before '.' exists in graph nodes.
+        Extractor nodes extract values from list[PromptMessage] for their parent node.
+        They have a parent_node_id field pointing to their parent node.
         """
-        if "." in node_id:
-            parent_id = node_id.rsplit(".", 1)[0]
-            return parent_id in self._graph.nodes
-        return False
+        node = self._graph.nodes.get(node_id)
+        if node is None:
+            return False
+        return node.node_data.is_extractor_node
 
-    def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
+    def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
         """
-        Handle virtual node started event.
+        Handle extractor node started event.
 
-        Virtual nodes don't need full execution tracking, just collect the event.
+        Extractor nodes don't need full execution tracking, just collect the event.
         """
         # Track in response coordinator for stream ordering
         self._response_coordinator.track_node_execution(event.node_id, event.id)
@@ -385,11 +385,11 @@ class EventHandler:
         # Collect the event
         self._event_collector.collect(event)
 
-    def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
+    def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
         """
-        Handle virtual node success event.
+        Handle extractor node success event.
 
-        Virtual nodes (extraction nodes) need special handling:
+        Extractor nodes need special handling:
         - Store outputs in variable pool (for reference by other nodes)
         - Accumulate token usage
         - Collect the event for logging
@@ -403,11 +403,11 @@ class EventHandler:
         # Collect the event
         self._event_collector.collect(event)
 
-    def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
+    def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
         """
-        Handle virtual node failed event.
+        Handle extractor node failed event.
 
-        Virtual nodes (extraction nodes) failures are collected for logging,
+        Extractor node failures are collected for logging,
         but the parent node is responsible for handling the error.
         """
         self._accumulate_node_usage(event.node_run_result.llm_usage)
diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py
index 52345ece82..f225798d41 100644
--- a/api/core/workflow/graph_events/node.py
+++ b/api/core/workflow/graph_events/node.py
@@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
     provider_type: str = ""
     provider_id: str = ""
 
-    # Virtual node fields for extraction
-    is_virtual: bool = False
-    parent_node_id: str | None = None
-    extraction_source: str | None = None  # e.g., "llm1.context"
-    extraction_prompt: str | None = None
-
 
 class NodeRunStreamChunkEvent(GraphNodeEventBase):
     # Spec-compliant fields
diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py
index e6cde91bea..87fd6c5b32 100644
--- a/api/core/workflow/nodes/base/__init__.py
+++ b/api/core/workflow/nodes/base/__init__.py
@@ -4,10 +4,8 @@ from .entities import (
     BaseLoopNodeData,
     BaseLoopState,
     BaseNodeData,
-    VirtualNodeConfig,
 )
 from .usage_tracking_mixin import LLMUsageTrackingMixin
-from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
 
 __all__ = [
     "BaseIterationNodeData",
@@ -16,7 +14,4 @@ __all__ = [
     "BaseLoopState",
     "BaseNodeData",
     "LLMUsageTrackingMixin",
-    "VirtualNodeConfig",
-    "VirtualNodeExecutionError",
-    "VirtualNodeExecutor",
 ]
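Before the node-level changes below, it may help to see the new segment type in action. A minimal sketch — the `UserPromptMessage`/`AssistantPromptMessage` classes are assumed to be exported by `core.model_runtime.entities` (the module imported throughout this diff); the exact usage here is illustrative, not taken from the changeset:

```python
import orjson

from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.variables.segments import ArrayPromptMessageSegment
from core.variables.utils import segment_orjson_default

segment = ArrayPromptMessageSegment(
    value=[
        UserPromptMessage(content="What's the weather in Paris?"),
        AssistantPromptMessage(content="Let me check that for you."),
    ]
)

# to_object() flattens each PromptMessage to a dict for DB storage / frontend.
assert all(isinstance(m, dict) for m in segment.to_object())

# segment_orjson_default makes the segment round-trip through orjson as an
# array of message objects.
data = orjson.dumps(segment, default=segment_orjson_default)
assert orjson.loads(data)[0]["content"] == "What's the weather in Paris?"
```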
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 41469d6ee8..fa8673db5f 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
         return self
 
 
-class VirtualNodeConfig(BaseModel):
-    """Configuration for a virtual sub-node embedded within a parent node."""
-
-    # Local ID within parent node (e.g., "ext_1")
-    # Will be converted to global ID: "{parent_id}.{id}"
-    id: str
-
-    # Node type (e.g., "llm", "code", "tool")
-    type: str
-
-    # Full node data configuration
-    data: dict[str, Any] = {}
-
-    def get_global_id(self, parent_node_id: str) -> str:
-        """Get the global node ID by combining parent ID and local ID."""
-        return f"{parent_node_id}.{self.id}"
-
-
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: str | None = None
@@ -193,8 +175,15 @@ class BaseNodeData(ABC, BaseModel):
     default_value: list[DefaultValue] | None = None
     retry_config: RetryConfig = RetryConfig()
 
-    # Virtual sub-nodes that execute before the main node
-    virtual_nodes: list[VirtualNodeConfig] = []
+    # Parent node ID when this node is used as an extractor.
+    # If set, this node is an "attached" extractor node that extracts values
+    # from list[PromptMessage] for the parent node's parameters.
+    parent_node_id: str | None = None
+
+    @property
+    def is_extractor_node(self) -> bool:
+        """Check if this node is an extractor node (has parent_node_id)."""
+        return self.parent_node_id is not None
 
     @property
     def default_value_dict(self) -> dict[str, Any]:
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index d49910c9fb..50314ea630 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
         self._node_id = node_id
         self._node_execution_id: str = ""
         self._start_at = naive_utc_now()
-        self._virtual_node_outputs: dict[str, Any] = {}  # Outputs from virtual sub-nodes
 
         raw_node_data = config.get("data") or {}
         if not isinstance(raw_node_data, Mapping):
@@ -271,51 +270,81 @@ class Node(Generic[NodeDataT]):
         """Check if execution should be stopped."""
         return self.graph_runtime_state.stop_event.is_set()
 
-    def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
+    def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
         """
-        Execute all virtual sub-nodes defined in node configuration.
-
-        Virtual nodes are complete node definitions that execute before the main node.
-        Each virtual node:
-        - Has its own global ID: "{parent_id}.{local_id}"
-        - Generates standard node events
-        - Stores outputs in the variable pool (via event handling)
-        - Supports retry via parent node's retry config
+        Find all extractor node configurations that have parent_node_id == self._node_id.
 
         Returns:
-            dict mapping local_id -> outputs dict
+            List of node configuration dicts for extractor nodes
         """
-        from .virtual_node_executor import VirtualNodeExecutor
+        nodes = self.graph_config.get("nodes", [])
+        extractor_configs = []
+        for node_config in nodes:
+            node_data = node_config.get("data", {})
+            if node_data.get("parent_node_id") == self._node_id:
+                extractor_configs.append(node_config)
+        return extractor_configs
 
-        virtual_nodes = self.node_data.virtual_nodes
-        if not virtual_nodes:
-            return {}
-
-        executor = VirtualNodeExecutor(
-            graph_init_params=self._graph_init_params,
-            graph_runtime_state=self.graph_runtime_state,
-            parent_node_id=self._node_id,
-            parent_retry_config=self.retry_config,
-        )
-
-        return (yield from executor.execute_virtual_nodes(virtual_nodes))
-
-    @property
-    def virtual_node_outputs(self) -> dict[str, Any]:
+    def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
         """
-        Get the outputs from virtual sub-nodes.
+        Execute all extractor nodes associated with this node.
 
-        Returns:
-            dict mapping local_id -> outputs dict
+        Extractor nodes are nodes with parent_node_id == self._node_id.
+        They are executed before the main node to extract values from list[PromptMessage].
         """
-        return self._virtual_node_outputs
+        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+
+        extractor_configs = self._find_extractor_node_configs()
+        logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
+        if not extractor_configs:
+            return
+
+        for config in extractor_configs:
+            node_id = config.get("id")
+            node_data = config.get("data", {})
+            node_type_str = node_data.get("type")
+
+            if not node_id or not node_type_str:
+                continue
+
+            # Get node class
+            try:
+                node_type = NodeType(node_type_str)
+            except ValueError:
+                continue
+
+            node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
+            if not node_mapping:
+                continue
+
+            node_version = str(node_data.get("version", "1"))
+            node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
+            if not node_cls:
+                continue
+
+            # Instantiate and execute the extractor node
+            extractor_node = node_cls(
+                id=node_id,
+                config=config,
+                graph_init_params=self._graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+            )
+
+            # Execute and process extractor node events
+            for event in extractor_node.run():
+                if isinstance(event, NodeRunSucceededEvent):
+                    # Store extractor node outputs in variable pool
+                    outputs = event.node_run_result.outputs
+                    for variable_name, variable_value in outputs.items():
+                        self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
+                yield event
 
     def run(self) -> Generator[GraphNodeEventBase, None, None]:
         execution_id = self.ensure_execution_id()
         self._start_at = naive_utc_now()
 
-        # Step 1: Execute virtual sub-nodes before main node execution
-        self._virtual_node_outputs = yield from self._execute_virtual_nodes()
+        # Step 1: Execute associated extractor nodes before main node execution
+        yield from self._execute_extractor_nodes()
 
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
diff --git a/api/core/workflow/nodes/base/virtual_node_executor.py b/api/core/workflow/nodes/base/virtual_node_executor.py
deleted file mode 100644
index 3f3b8f1f99..0000000000
--- a/api/core/workflow/nodes/base/virtual_node_executor.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""
-Virtual Node Executor for running embedded sub-nodes within a parent node.
-
-This module handles the execution of virtual nodes defined in a parent node's
-`virtual_nodes` configuration. Virtual nodes are complete node definitions
-that execute before the parent node.
-
-Example configuration:
-    virtual_nodes:
-      - id: ext_1
-        type: llm
-        data:
-          model: {...}
-          prompt_template: [...]
-"""
-
-import time
-from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
-from uuid import uuid4
-
-from core.workflow.enums import NodeType
-from core.workflow.graph_events import (
-    GraphNodeEventBase,
-    NodeRunFailedEvent,
-    NodeRunRetryEvent,
-    NodeRunStartedEvent,
-    NodeRunSucceededEvent,
-)
-from libs.datetime_utils import naive_utc_now
-
-from .entities import RetryConfig, VirtualNodeConfig
-
-if TYPE_CHECKING:
-    from core.workflow.entities import GraphInitParams
-    from core.workflow.runtime import GraphRuntimeState
-
-
-class VirtualNodeExecutionError(Exception):
-    """Error during virtual node execution"""
-
-    def __init__(self, node_id: str, original_error: Exception):
-        self.node_id = node_id
-        self.original_error = original_error
-        super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
-
-
-class VirtualNodeExecutor:
-    """
-    Executes virtual sub-nodes embedded within a parent node.
-
-    Virtual nodes are complete node definitions that execute before the parent node.
-    Each virtual node:
-    - Has its own global ID: "{parent_id}.{local_id}"
-    - Generates standard node events
-    - Stores outputs in the variable pool
-    - Supports retry via parent node's retry config
-    """
-
-    def __init__(
-        self,
-        *,
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
-        parent_node_id: str,
-        parent_retry_config: RetryConfig | None = None,
-    ):
-        self._graph_init_params = graph_init_params
-        self._graph_runtime_state = graph_runtime_state
-        self._parent_node_id = parent_node_id
-        self._parent_retry_config = parent_retry_config or RetryConfig()
-
-    def execute_virtual_nodes(
-        self,
-        virtual_nodes: list[VirtualNodeConfig],
-    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
-        """
-        Execute all virtual nodes in order.
-
-        Args:
-            virtual_nodes: List of virtual node configurations
-
-        Yields:
-            Node events from each virtual node execution
-
-        Returns:
-            dict mapping local_id -> outputs dict
-        """
-        results: dict[str, Any] = {}
-
-        for vnode_config in virtual_nodes:
-            global_id = vnode_config.get_global_id(self._parent_node_id)
-
-            # Execute with retry
-            outputs = yield from self._execute_with_retry(vnode_config, global_id)
-            results[vnode_config.id] = outputs
-
-        return results
-
-    def _execute_with_retry(
-        self,
-        vnode_config: VirtualNodeConfig,
-        global_id: str,
-    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
-        """
-        Execute virtual node with retry support.
-        """
-        retry_config = self._parent_retry_config
-        last_error: Exception | None = None
-
-        for attempt in range(retry_config.max_retries + 1):
-            try:
-                return (yield from self._execute_single_node(vnode_config, global_id))
-            except Exception as e:
-                last_error = e
-
-                if attempt < retry_config.max_retries:
-                    # Yield retry event
-                    yield NodeRunRetryEvent(
-                        id=str(uuid4()),
-                        node_id=global_id,
-                        node_type=self._get_node_type(vnode_config.type),
-                        node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
-                        start_at=naive_utc_now(),
-                        error=str(e),
-                        retry_index=attempt + 1,
-                    )
-
-                    time.sleep(retry_config.retry_interval_seconds)
-                    continue
-
-                raise VirtualNodeExecutionError(global_id, e) from e
-
-        raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
-
-    def _execute_single_node(
-        self,
-        vnode_config: VirtualNodeConfig,
-        global_id: str,
-    ) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
-        """
-        Execute a single virtual node by instantiating and running it.
-        """
-        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-
-        # Build node config
-        node_config: dict[str, Any] = {
-            "id": global_id,
-            "data": {
-                **vnode_config.data,
-                "title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
-            },
-        }
-
-        # Get the node class for this type
-        node_type = self._get_node_type(vnode_config.type)
-        node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
-        if not node_mapping:
-            raise ValueError(f"No class mapping found for node type: {node_type}")
-
-        node_version = str(vnode_config.data.get("version", "1"))
-        node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
-        if not node_cls:
-            raise ValueError(f"No class found for node type: {node_type}")
-
-        # Instantiate the node
-        node = node_cls(
-            id=global_id,
-            config=node_config,
-            graph_init_params=self._graph_init_params,
-            graph_runtime_state=self._graph_runtime_state,
-        )
-
-        # Run and collect events
-        outputs: dict[str, Any] = {}
-
-        for event in node.run():
-            # Mark event as coming from virtual node
-            self._mark_event_as_virtual(event, vnode_config)
-            yield event
-
-            if isinstance(event, NodeRunSucceededEvent):
-                outputs = event.node_run_result.outputs or {}
-            elif isinstance(event, NodeRunFailedEvent):
-                raise Exception(event.error or "Virtual node execution failed")
-
-        return outputs
-
-    def _mark_event_as_virtual(
-        self,
-        event: GraphNodeEventBase,
-        vnode_config: VirtualNodeConfig,
-    ) -> None:
-        """Mark event as coming from a virtual node."""
-        if isinstance(event, NodeRunStartedEvent):
-            event.is_virtual = True
-            event.parent_node_id = self._parent_node_id
-
-    def _get_node_type(self, type_str: str) -> NodeType:
-        """Convert type string to NodeType enum."""
-        type_mapping = {
-            "llm": NodeType.LLM,
-            "code": NodeType.CODE,
-            "tool": NodeType.TOOL,
-            "if-else": NodeType.IF_ELSE,
-            "question-classifier": NodeType.QUESTION_CLASSIFIER,
-            "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
-            "template-transform": NodeType.TEMPLATE_TRANSFORM,
-            "variable-assigner": NodeType.VARIABLE_ASSIGNER,
-            "http-request": NodeType.HTTP_REQUEST,
-            "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
-        }
-        return type_mapping.get(type_str, NodeType.LLM)
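With the executor gone, an extractor is declared as a regular graph node whose `data` carries `parent_node_id`, and discovery is a linear scan of the node list. A standalone sketch of the lookup — the config shape follows the `graph_config` usage above, but the workflow and IDs here are hypothetical:

```python
from typing import Any

def find_extractor_configs(graph_config: dict[str, Any], parent_id: str) -> list[dict[str, Any]]:
    """Equivalent of Node._find_extractor_node_configs: match on data.parent_node_id."""
    return [
        node_config
        for node_config in graph_config.get("nodes", [])
        if node_config.get("data", {}).get("parent_node_id") == parent_id
    ]

# Hypothetical workflow: an LLM extractor attached to a tool node.
graph_config = {
    "nodes": [
        {"id": "weather_tool", "data": {"type": "tool", "title": "Weather"}},
        {"id": "extract_city", "data": {"type": "llm", "title": "Extract city", "parent_node_id": "weather_tool"}},
    ]
}
assert [c["id"] for c in find_extractor_configs(graph_config, "weather_tool")] == ["extract_city"]
```

Because `BaseNodeData.is_extractor_node` is derived from the same field, the event handlers and the parent node agree on which nodes are extractors without relying on any ID-format convention.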
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index fe6f2290aa..c7db88891f 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
     jinja2_text: str | None = None
 
 
+class PromptMessageContext(BaseModel):
+    """Context variable reference in prompt template.
+
+    YAML/JSON format: { "$context": ["node_id", "variable_name"] }
+    This will be expanded to list[PromptMessage] at runtime.
+    """
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    value_selector: Sequence[str] = Field(alias="$context")
+
+
+# Union type for prompt template items (static message or context variable reference)
+PromptTemplateItem: TypeAlias = Annotated[
+    LLMNodeChatModelMessage | PromptMessageContext,
+    Field(discriminator=None),
+]
+
+
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
-    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: MemoryConfig | None = None
     context: ContextConfig
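A quick check of the `$context` parsing, relying only on the pydantic definitions added above:

```python
from core.workflow.nodes.llm.entities import PromptMessageContext

# The DSL form {"$context": [...]} populates value_selector via the field alias.
item = PromptMessageContext.model_validate({"$context": ["llm1", "context"]})
assert list(item.value_selector) == ["llm1", "context"]

# populate_by_name=True additionally allows the plain field name in Python code.
item2 = PromptMessageContext(value_selector=["llm1", "context"])
assert list(item2.value_selector) == list(item.value_selector)
```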
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index e69186e2b0..02ab4ee7a0 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -7,7 +7,7 @@ import logging
 import re
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from sqlalchemy import select
 
@@ -52,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.tools.signature import sign_upload_file
 from core.variables import (
     ArrayFileSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     FileSegment,
     NoneSegment,
@@ -88,6 +89,7 @@ from .entities import (
     LLMNodeCompletionModelPromptTemplate,
     LLMNodeData,
     ModelConfig,
+    PromptMessageContext,
 )
 from .exc import (
     InvalidContextStructureError,
@@ -160,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool = self.graph_runtime_state.variable_pool
 
         try:
-            # init messages template
-            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+            # Parse prompt template to separate static messages and context references
+            prompt_template = self.node_data.prompt_template
+            static_messages, context_refs, template_order = self._parse_prompt_template()
 
             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data=self.node_data)
@@ -223,21 +226,40 @@ class LLMNode(Node[LLMNodeData]):
             ):
                 query = query_variable.text
 
-            prompt_messages, stop = LLMNode.fetch_prompt_messages(
-                sys_query=query,
-                sys_files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config,
-                prompt_template=self.node_data.prompt_template,
-                memory_config=self.node_data.memory,
-                vision_enabled=self.node_data.vision.enabled,
-                vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=variable_pool,
-                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-                tenant_id=self.tenant_id,
-                context_files=context_files,
-            )
+            # Get prompt messages
+            prompt_messages: Sequence[PromptMessage]
+            stop: Sequence[str] | None
+            if isinstance(prompt_template, list) and context_refs:
+                prompt_messages, stop = self._build_prompt_messages_with_context(
+                    context_refs=context_refs,
+                    template_order=template_order,
+                    static_messages=static_messages,
+                    query=query,
+                    files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    context_files=context_files,
+                )
+            else:
+                prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                    sys_query=query,
+                    sys_files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    prompt_template=cast(
+                        Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+                        self.node_data.prompt_template,
+                    ),
+                    memory_config=self.node_data.memory,
+                    vision_enabled=self.node_data.vision.enabled,
+                    vision_detail=self.node_data.vision.configs.detail,
+                    variable_pool=variable_pool,
+                    jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                    tenant_id=self.tenant_id,
+                    context_files=context_files,
+                )
 
             # handle invoke result
             generator = LLMNode.invoke_llm(
@@ -304,7 +326,7 @@ class LLMNode(Node[LLMNodeData]):
                 "reasoning_content": reasoning_content,
                 "usage": jsonable_encoder(usage),
                 "finish_reason": finish_reason,
-                "context": self._build_context(prompt_messages, clean_text, model_config.mode),
+                "context": self._build_context(prompt_messages, clean_text),
             }
             if structured_output:
                 outputs["structured_output"] = structured_output.structured_output
@@ -602,17 +624,15 @@ class LLMNode(Node[LLMNodeData]):
     def _build_context(
         prompt_messages: Sequence[PromptMessage],
         assistant_response: str,
-        model_mode: str,
-    ) -> list[dict[str, Any]]:
+    ) -> list[PromptMessage]:
         """
         Build context from prompt messages and assistant response.
 
         Excludes system messages and includes the current LLM response.
+        Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
         """
         context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
         context_messages.append(AssistantPromptMessage(content=assistant_response))
-        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-            model_mode=model_mode, prompt_messages=context_messages
-        )
+        return context_messages
 
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@@ -629,6 +649,106 @@ class LLMNode(Node[LLMNodeData]):
 
         return messages
 
+    def _parse_prompt_template(
+        self,
+    ) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
+        """
+        Parse prompt_template to separate static messages and context references.
+
+        Returns:
+            Tuple of (static_messages, context_refs, template_order)
+            - static_messages: list of LLMNodeChatModelMessage
+            - context_refs: list of PromptMessageContext
+            - template_order: list of (index, type) tuples preserving original order
+        """
+        prompt_template = self.node_data.prompt_template
+        static_messages: list[LLMNodeChatModelMessage] = []
+        context_refs: list[PromptMessageContext] = []
+        template_order: list[tuple[int, str]] = []
+
+        if isinstance(prompt_template, list):
+            for idx, item in enumerate(prompt_template):
+                if isinstance(item, PromptMessageContext):
+                    context_refs.append(item)
+                    template_order.append((idx, "context"))
+                else:
+                    static_messages.append(item)
+                    template_order.append((idx, "static"))
+            # Transform static messages for jinja2
+            if static_messages:
+                self.node_data.prompt_template = self._transform_chat_messages(static_messages)
+
+        return static_messages, context_refs, template_order
+
+    def _build_prompt_messages_with_context(
+        self,
+        *,
+        context_refs: list[PromptMessageContext],
+        template_order: list[tuple[int, str]],
+        static_messages: list[LLMNodeChatModelMessage],
+        query: str | None,
+        files: Sequence[File],
+        context: str | None,
+        memory: BaseMemory | None,
+        model_config: ModelConfigWithCredentialsEntity,
+        context_files: list[File],
+    ) -> tuple[list[PromptMessage], Sequence[str] | None]:
+        """
+        Build prompt messages by combining static messages and context references in DSL order.
+
+        Returns:
+            Tuple of (prompt_messages, stop_sequences)
+        """
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        # Build a map from context index to its messages
+        context_messages_map: dict[int, list[PromptMessage]] = {}
+        context_idx = 0
+        for idx, type_ in template_order:
+            if type_ == "context":
+                ctx_ref = context_refs[context_idx]
+                ctx_var = variable_pool.get(ctx_ref.value_selector)
+                if ctx_var is None:
+                    raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
+                if not isinstance(ctx_var, ArrayPromptMessageSegment):
+                    raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
+                context_messages_map[idx] = list(ctx_var.value)
+                context_idx += 1
+
+        # Process static messages
+        static_prompt_messages: Sequence[PromptMessage] = []
+        stop: Sequence[str] | None = None
+        if static_messages:
+            static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                sys_query=query,
+                sys_files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config,
+                prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
+                memory_config=self.node_data.memory,
+                vision_enabled=self.node_data.vision.enabled,
+                vision_detail=self.node_data.vision.configs.detail,
+                variable_pool=variable_pool,
+                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                tenant_id=self.tenant_id,
+                context_files=context_files,
+            )
+
+        # Combine messages according to original DSL order
+        combined_messages: list[PromptMessage] = []
+        static_msg_iter = iter(static_prompt_messages)
+        for idx, type_ in template_order:
+            if type_ == "context":
+                combined_messages.extend(context_messages_map[idx])
+            else:
+                if msg := next(static_msg_iter, None):
+                    combined_messages.append(msg)
+        # Append any remaining static messages (e.g., memory messages)
+        combined_messages.extend(static_msg_iter)
+
+        return combined_messages, stop
+
     def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         variables: dict[str, Any] = {}
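Stripped of node plumbing, the order-preserving merge in `_build_prompt_messages_with_context` reduces to the following sketch; note that, like the real method, it drains leftover static messages (e.g., memory turns appended by `fetch_prompt_messages` beyond the template):

```python
from typing import Any

def merge_in_dsl_order(
    template_order: list[tuple[int, str]],
    context_map: dict[int, list[Any]],
    static_msgs: list[Any],
) -> list[Any]:
    """Interleave static messages with expanded context blocks by original index."""
    combined: list[Any] = []
    static_iter = iter(static_msgs)
    for idx, kind in template_order:
        if kind == "context":
            combined.extend(context_map[idx])
        elif (msg := next(static_iter, None)) is not None:
            combined.append(msg)
    combined.extend(static_iter)  # anything appended beyond the template, e.g. memory
    return combined

order = [(0, "static"), (1, "context"), (2, "static")]
assert merge_in_dsl_order(order, {1: ["m1", "m2"]}, ["s0", "s2", "mem"]) == ["s0", "m1", "m2", "s2", "mem"]
```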
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
index c1cfbb1edc..72e71b020b 100644
--- a/api/core/workflow/nodes/tool/entities.py
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Any, Literal, Union
 
 from pydantic import BaseModel, field_validator
@@ -7,6 +8,31 @@ from core.tools.entities.tool_entities import ToolProviderType
 from core.workflow.nodes.base.entities import BaseNodeData
 
 
+class MentionValue(BaseModel):
+    """Value structure for mention type parameters.
+
+    Used when a tool parameter needs to be extracted from conversation context
+    using an extractor LLM node.
+    """
+
+    # Variable selector for list[PromptMessage] input to extractor
+    variable_selector: Sequence[str]
+
+    # ID of the extractor LLM node
+    extractor_node_id: str
+
+    # Output variable selector from extractor node
+    # e.g., ["text"], ["structured_output", "query"]
+    output_selector: Sequence[str]
+
+    # Strategy when output is None
+    null_strategy: Literal["raise_error", "use_default"] = "raise_error"
+
+    # Default value when null_strategy is "use_default"
+    # Type should match the parameter's expected type
+    default_value: Any = None
+
+
 class ToolEntity(BaseModel):
     provider_id: str
     provider_type: ToolProviderType
@@ -34,8 +60,8 @@ class ToolEntity(BaseModel):
 class ToolNodeData(BaseNodeData, ToolEntity):
     class ToolInput(BaseModel):
         # TODO: check this type
-        value: Union[Any, list[str]]
-        type: Literal["mixed", "variable", "constant"]
+        value: Union[Any, list[str], MentionValue]
+        type: Literal["mixed", "variable", "constant", "mention"]
 
         @field_validator("type", mode="before")
         @classmethod
@@ -56,6 +82,17 @@ class ToolNodeData(BaseNodeData, ToolEntity):
                     raise ValueError("value must be a list of strings")
             elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
                 raise ValueError("value must be a string, int, float, bool or dict")
+            elif typ == "mention":
+                # Mention type: value should be a MentionValue or dict with required fields
+                if isinstance(value, MentionValue):
+                    pass  # Already validated by Pydantic
+                elif isinstance(value, dict):
+                    if "extractor_node_id" not in value:
+                        raise ValueError("value must contain extractor_node_id for mention type")
+                    if "output_selector" not in value:
+                        raise ValueError("value must contain output_selector for mention type")
+                else:
+                    raise ValueError("value must be a MentionValue or dict for mention type")
             return typ
 
     tool_parameters: dict[str, ToolInput]
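Put together, a mention-typed tool input looks roughly like this; the node IDs and selectors are invented for illustration:

```python
from core.workflow.nodes.tool.entities import MentionValue, ToolNodeData

# Hypothetical: a "city" parameter filled by an attached extractor LLM node.
city_input = ToolNodeData.ToolInput(
    type="mention",
    value=MentionValue(
        variable_selector=["llm1", "context"],          # list[PromptMessage] fed to the extractor
        extractor_node_id="extract_city",               # the node carrying parent_node_id
        output_selector=["structured_output", "city"],  # path inside the extractor's outputs
        null_strategy="use_default",
        default_value="Unknown",
    ),
)
assert city_input.type == "mention"
```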
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 0ba58a9560..7752dc0f46 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,7 +1,10 @@
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import select
+
+logger = logging.getLogger(__name__)
 from sqlalchemy.orm import Session
 
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -89,20 +92,18 @@ class ToolNode(Node[ToolNodeData]):
             )
             return
 
-        # get parameters (use virtual_node_outputs from base class)
+        # get parameters
         tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
         parameters = self._generate_parameters(
             tool_parameters=tool_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
             node_data=self.node_data,
-            virtual_node_outputs=self.virtual_node_outputs,
         )
         parameters_for_log = self._generate_parameters(
             tool_parameters=tool_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
             node_data=self.node_data,
             for_log=True,
-            virtual_node_outputs=self.virtual_node_outputs,
         )
 
         # get conversation id
         conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -178,7 +179,6 @@ class ToolNode(Node[ToolNodeData]):
         variable_pool: "VariablePool",
         node_data: ToolNodeData,
         for_log: bool = False,
-        virtual_node_outputs: dict[str, Any] | None = None,
     ) -> dict[str, Any]:
         """
         Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -188,16 +188,12 @@ class ToolNode(Node[ToolNodeData]):
             variable_pool (VariablePool): The variable pool containing the variables.
             node_data (ToolNodeData): The data associated with the tool node.
             for_log (bool): Whether to generate parameters for logging.
-            virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
-                Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
-                with global IDs like "{parent_id}.{local_id}".
 
         Returns:
             Mapping[str, Any]: A dictionary containing the generated parameters.
         """
         tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
-        virtual_node_outputs = virtual_node_outputs or {}
 
         result: dict[str, Any] = {}
         for parameter_name in node_data.tool_parameters:
@@ -207,22 +203,39 @@ class ToolNode(Node[ToolNodeData]):
                 continue
             tool_input = node_data.tool_parameters[parameter_name]
             if tool_input.type == "variable":
-                # Check if this references a virtual node output (local ID like [ext_1, text])
+                if not isinstance(tool_input.value, list):
+                    raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
                 selector = tool_input.value
-                if len(selector) >= 2 and selector[0] in virtual_node_outputs:
-                    # Reference to virtual node output
-                    local_id = selector[0]
-                    var_name = selector[1]
-                    outputs = virtual_node_outputs.get(local_id, {})
-                    parameter_value = outputs.get(var_name)
+                variable = variable_pool.get(selector)
+                if variable is None:
+                    if parameter.required:
+                        raise ToolParameterError(f"Variable {selector} does not exist")
+                    continue
+                parameter_value = variable.value
+            elif tool_input.type == "mention":
+                # Mention type: get value from extractor node's output
+                from .entities import MentionValue
+
+                mention_value = tool_input.value
+                if isinstance(mention_value, MentionValue):
+                    mention_config = mention_value.model_dump()
+                elif isinstance(mention_value, dict):
+                    mention_config = mention_value
                 else:
-                    # Normal variable reference
-                    variable = variable_pool.get(selector)
-                    if variable is None:
-                        if parameter.required:
-                            raise ToolParameterError(f"Variable {selector} does not exist")
+                    raise ToolParameterError(f"Invalid mention value for parameter '{parameter_name}'")
+
+                try:
+                    parameter_value, found = variable_pool.resolve_mention(
+                        mention_config, parameter_name=parameter_name
+                    )
+                    if not found and parameter.required:
+                        raise ToolParameterError(
+                            f"Extractor output not found for required parameter '{parameter_name}'"
+                        )
+                    if not found:
                         continue
-                    parameter_value = variable.value
+                except ValueError as e:
+                    raise ToolParameterError(str(e)) from e
             elif tool_input.type in {"mixed", "constant"}:
                 template = str(tool_input.value)
                 segment_group = variable_pool.convert_template(template)
@@ -507,8 +520,12 @@ class ToolNode(Node[ToolNodeData]):
                 for selector in selectors:
                     result[selector.variable] = selector.value_selector
             elif input.type == "variable":
-                selector_key = ".".join(input.value)
-                result[f"#{selector_key}#"] = input.value
+                if isinstance(input.value, list):
+                    selector_key = ".".join(input.value)
+                    result[f"#{selector_key}#"] = input.value
+            elif input.type == "mention":
+                # Mention type handled by extractor node, no direct variable reference
+                pass
             elif input.type == "constant":
                 pass
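For orientation, the tool node's resolution of the two selector-based input types now follows this shape — a simplified sketch of `_generate_parameters`, not the actual method:

```python
# Simplified control flow for resolving one tool parameter.
# `pool` is a VariablePool; `tool_input` is a ToolNodeData.ToolInput.
def resolve_input(pool, tool_input, parameter_name: str, required: bool):
    if tool_input.type == "variable":
        if not isinstance(tool_input.value, list):
            raise ValueError(f"Invalid variable selector for '{parameter_name}'")
        variable = pool.get(tool_input.value)
        if variable is None:
            if required:
                raise ValueError(f"Variable {tool_input.value} does not exist")
            return None
        return variable.value
    if tool_input.type == "mention":
        config = tool_input.value if isinstance(tool_input.value, dict) else tool_input.value.model_dump()
        value, found = pool.resolve_mention(config, parameter_name=parameter_name)
        if not found and required:
            raise ValueError(f"Extractor output not found for '{parameter_name}'")
        return value
    raise ValueError(f"unsupported input type {tool_input.type!r}")
```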
diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py
index 85ceb9d59e..f456f61dd0 100644
--- a/api/core/workflow/runtime/variable_pool.py
+++ b/api/core/workflow/runtime/variable_pool.py
@@ -268,6 +268,58 @@ class VariablePool(BaseModel):
                 continue
             self.add(selector, value)
 
+    def resolve_mention(
+        self,
+        mention_config: Mapping[str, Any],
+        /,
+        *,
+        parameter_name: str = "",
+    ) -> tuple[Any, bool]:
+        """
+        Resolve a mention parameter value from an extractor node's output.
+
+        Mention parameters reference values extracted by an extractor LLM node
+        from list[PromptMessage] context.
+
+        Args:
+            mention_config: A dict containing:
+                - extractor_node_id: ID of the extractor LLM node
+                - output_selector: Selector path for the output variable (e.g., ["text"])
+                - null_strategy: "raise_error" or "use_default"
+                - default_value: Value to use when null_strategy is "use_default"
+            parameter_name: Name of the parameter being resolved (for error messages)
+
+        Returns:
+            Tuple of (resolved_value, found):
+            - resolved_value: The extracted value, or default_value if not found
+            - found: True if value was found, False if using default
+
+        Raises:
+            ValueError: If extractor_node_id is missing, or if null_strategy is
+                "raise_error" and the value is not found
+        """
+        extractor_node_id = mention_config.get("extractor_node_id")
+        if not extractor_node_id:
+            raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
+
+        output_selector = list(mention_config.get("output_selector", []))
+        null_strategy = mention_config.get("null_strategy", "raise_error")
+        default_value = mention_config.get("default_value")
+
+        # Build full selector: [extractor_node_id, ...output_selector]
+        full_selector = [extractor_node_id] + output_selector
+        variable = self.get(full_selector)
+
+        if variable is None:
+            if null_strategy == "use_default":
+                return default_value, False
+            raise ValueError(
+                f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
+                f"not found for parameter '{parameter_name}'"
+            )
+
+        return variable.value, True
+
     @classmethod
     def empty(cls) -> VariablePool:
         """Create an empty variable pool."""
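`resolve_mention` can be exercised directly against a pool. A sketch assuming `VariablePool.add` accepts a `(node_id, name)` selector, as used in `Node._execute_extractor_nodes` above:

```python
from core.workflow.runtime.variable_pool import VariablePool

pool = VariablePool.empty()
pool.add(("extract_city", "text"), "Paris")  # simulate an extractor node's stored output

value, found = pool.resolve_mention(
    {"extractor_node_id": "extract_city", "output_selector": ["text"]},
    parameter_name="city",
)
assert (value, found) == ("Paris", True)

# With use_default, a missing output falls back instead of raising.
value, found = pool.resolve_mention(
    {
        "extractor_node_id": "missing_node",
        "output_selector": ["text"],
        "null_strategy": "use_default",
        "default_value": "Unknown",
    },
    parameter_name="city",
)
assert (value, found) == ("Unknown", False)
```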
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 494194369a..fb697a8c29 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -4,6 +4,7 @@ from uuid import uuid4
 
 from configs import dify_config
 from core.file import File
+from core.model_runtime.entities import PromptMessage
 from core.variables.exc import VariableError
 from core.variables.segments import (
     ArrayAnySegment,
@@ -11,6 +12,7 @@ from core.variables.segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     BooleanSegment,
@@ -29,6 +31,7 @@ from core.variables.variables import (
     ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
+    ArrayPromptMessageVariable,
     ArrayStringVariable,
     BooleanVariable,
     FileVariable,
@@ -61,6 +64,7 @@ SEGMENT_TO_VARIABLE_MAP = {
     ArrayFileSegment: ArrayFileVariable,
     ArrayNumberSegment: ArrayNumberVariable,
     ArrayObjectSegment: ArrayObjectVariable,
+    ArrayPromptMessageSegment: ArrayPromptMessageVariable,
     ArrayStringSegment: ArrayStringVariable,
     BooleanSegment: BooleanVariable,
     FileSegment: FileVariable,
@@ -156,7 +160,13 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, File):
         return FileSegment(value=value)
+    if isinstance(value, PromptMessage):
+        # Single PromptMessage should be wrapped in a list
+        return ArrayPromptMessageSegment(value=[value])
     if isinstance(value, list):
+        # Check if all items are PromptMessage
+        if value and all(isinstance(item, PromptMessage) for item in value):
+            return ArrayPromptMessageSegment(value=value)
         items = [build_segment(item) for item in value]
         types = {item.value_type for item in items}
         if all(isinstance(item, ArraySegment) for item in items):
@@ -200,6 +210,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
     SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
     SegmentType.ARRAY_FILE: ArrayFileSegment,
     SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
+    SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
 }
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 072c6100b5..7be51c05b6 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1291,7 +1291,7 @@ class WorkflowDraftVariable(Base):
     # which may differ from the original value's type. Typically, they are the same,
     # but in cases where the structurally truncated value still exceeds the size limit,
     # text slicing is applied, and the `value_type` is converted to `STRING`.
-    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
+    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
 
     # The variable's value serialized as a JSON string
     #
@@ -1665,7 +1665,7 @@ class WorkflowDraftVariableFile(Base):
     # The `value_type` field records the type of the original value.
     value_type: Mapped[SegmentType] = mapped_column(
-        EnumText(SegmentType, length=20),
+        EnumText(SegmentType, length=21),
         nullable=False,
     )
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index f973361341..9d587c7850 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -7,6 +7,7 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
 
 from configs import dify_config
 from core.file.models import File
+from core.model_runtime.entities import PromptMessage
 from core.variables.segments import (
     ArrayFileSegment,
     ArraySegment,
@@ -287,6 +288,10 @@ class VariableTruncator(BaseTruncator):
             if isinstance(item, File):
                 truncated_value.append(item)
                 continue
+            # Keep PromptMessage items as-is; they are serialized later, not truncated here
+            if isinstance(item, PromptMessage):
+                truncated_value.append(item)
+                continue
             if i >= target_length:
                 return _PartResult(truncated_value, used_size, True)
             if i > 0:
diff --git a/api/tests/unit_tests/core/workflow/entities/test_virtual_node.py b/api/tests/unit_tests/core/workflow/entities/test_virtual_node.py
deleted file mode 100644
index ffffccfa1b..0000000000
--- a/api/tests/unit_tests/core/workflow/entities/test_virtual_node.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""
-Unit tests for virtual node configuration.
-"""
-
-from core.workflow.nodes.base.entities import VirtualNodeConfig
-
-
-class TestVirtualNodeConfig:
-    """Tests for VirtualNodeConfig entity."""
-
-    def test_create_basic_config(self):
-        """Test creating a basic virtual node config."""
-        config = VirtualNodeConfig(
-            id="ext_1",
-            type="llm",
-            data={
-                "title": "Extract keywords",
-                "model": {"provider": "openai", "name": "gpt-4o-mini"},
-            },
-        )
-
-        assert config.id == "ext_1"
-        assert config.type == "llm"
-        assert config.data["title"] == "Extract keywords"
-
-    def test_get_global_id(self):
-        """Test generating global ID from parent ID."""
-        config = VirtualNodeConfig(
-            id="ext_1",
-            type="llm",
-            data={},
-        )
-
-        global_id = config.get_global_id("tool1")
-        assert global_id == "tool1.ext_1"
-
-    def test_get_global_id_with_different_parents(self):
-        """Test global ID generation with different parent IDs."""
-        config = VirtualNodeConfig(id="sub_node", type="code", data={})
-
-        assert config.get_global_id("parent1") == "parent1.sub_node"
-        assert config.get_global_id("node_123") == "node_123.sub_node"
-
-    def test_empty_data(self):
-        """Test virtual node config with empty data."""
-        config = VirtualNodeConfig(
-            id="test",
-            type="tool",
-        )
-
-        assert config.id == "test"
-        assert config.type == "tool"
-        assert config.data == {}
-
-    def test_complex_data(self):
-        """Test virtual node config with complex data."""
-        config = VirtualNodeConfig(
-            id="llm_1",
-            type="llm",
-            data={
-                "title": "Generate summary",
-                "model": {
-                    "provider": "openai",
-                    "name": "gpt-4",
-                    "mode": "chat",
-                    "completion_params": {"temperature": 0.7, "max_tokens": 500},
-                },
-                "prompt_template": [
-                    {"role": "user", "text": "{{#llm1.context#}}"},
-                    {"role": "user", "text": "Please summarize the conversation"},
-                ],
-            },
-        )
-
-        assert config.data["model"]["provider"] == "openai"
-        assert len(config.data["prompt_template"]) == 2
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
index 5d17b7a243..65bd3d87d4 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py
@@ -25,6 +25,12 @@ class _StubErrorHandler:
     """Minimal error handler stub for tests."""
 
 
+class _StubNodeData:
+    """Simple node data stub with is_extractor_node property."""
+
+    is_extractor_node = False
+
+
 class _StubNode:
     """Simple node stub exposing the attributes needed by the state manager."""
 
@@ -36,6 +42,7 @@ class _StubNode:
         self.error_strategy = None
         self.retry_config = RetryConfig()
         self.retry = False
+        self.node_data = _StubNodeData()
 
 
 def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: