diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f61ddb464a..05b668b803 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,6 +3,7 @@ import logging from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range +from sqlalchemy import exists, select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api @@ -94,21 +95,18 @@ class ChatMessageListApi(Resource): .all() ) - has_more = False if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = ( - db.session.query(Message) - .where( + + has_more = db.session.scalar( + select( + exists().where( Message.conversation_id == conversation.id, Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id, ) - .count() ) - - if rest_count > 0: - has_more = True + ) history_messages = list(reversed(history_messages)) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ac12d83ef2..8357dac0d7 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -8,20 +8,21 @@ from uuid import UUID import numpy as np import pytz -from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager +from libs.login import current_user +from models.account import Account logger = logging.getLogger(__name__) def safe_json_value(v): if isinstance(v, datetime): - tz_name = getattr(current_user, "timezone", None) if current_user is not None else None - if not tz_name: - tz_name = "UTC" + tz_name = "UTC" + if isinstance(current_user, Account) and current_user.timezone is not None: + tz_name = current_user.timezone return v.astimezone(pytz.timezone(tz_name)).isoformat() elif isinstance(v, date): return v.isoformat() @@ -46,7 +47,7 @@ def safe_json_value(v): return v -def safe_json_dict(d): +def safe_json_dict(d: dict): if not isinstance(d, dict): raise TypeError("safe_json_dict() expects a dictionary (dict) as input") return {k: safe_json_value(v) for k, v in d.items()} diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index b4c66ba27d..d8749f9851 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,8 +3,6 @@ import logging from collections.abc import Generator from typing import Any, Optional, cast -from flask_login import current_user - from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping -from models.account import Account -from models.model import App, EndUser +from libs.login import current_user +from models.model import App from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -79,11 +77,11 @@ class WorkflowTool(Tool): generator = WorkflowAppGenerator() assert self.runtime is not None assert self.runtime.invoke_from is not None - + assert current_user is not None result = generator.generate( app_model=app, workflow=workflow, - user=cast("Account | EndUser", current_user), + user=current_user, args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index f04f6ccc55..c5be9be02a 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -66,6 +66,7 @@ class NodeExecutionType(StrEnum): RESPONSE = "response" # Response nodes that stream outputs (Answer, End) BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier) CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph) + ROOT = "root" # Nodes that can serve as execution entry points class ErrorStrategy(StrEnum): diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 5bb02c8a7f..8372270a11 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -1,9 +1,9 @@ import logging from collections import defaultdict from collections.abc import Mapping -from typing import Any, Optional, Protocol, cast +from typing import Any, Protocol, cast -from core.workflow.enums import NodeType +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node from .edge import Edge @@ -36,10 +36,10 @@ class Graph: def __init__( self, *, - nodes: Optional[dict[str, Node]] = None, - edges: Optional[dict[str, Edge]] = None, - in_edges: Optional[dict[str, list[str]]] = None, - out_edges: Optional[dict[str, list[str]]] = None, + nodes: dict[str, Node] | None = None, + edges: dict[str, Edge] | None = None, + in_edges: dict[str, list[str]] | None = None, + out_edges: dict[str, list[str]] | None = None, root_node: Node, ): """ @@ -81,7 +81,7 @@ class Graph: cls, node_configs_map: dict[str, dict[str, Any]], edge_configs: list[dict[str, Any]], - root_node_id: Optional[str] = None, + root_node_id: str | None = None, ) -> str: """ Find the root node ID if not specified. @@ -186,13 +186,79 @@ class Graph: return nodes + @classmethod + def _mark_inactive_root_branches( + cls, + nodes: dict[str, Node], + edges: dict[str, Edge], + in_edges: dict[str, list[str]], + out_edges: dict[str, list[str]], + active_root_id: str, + ) -> None: + """ + Mark nodes and edges from inactive root branches as skipped. + + Algorithm: + 1. Mark inactive root nodes as skipped + 2. For skipped nodes, mark all their outgoing edges as skipped + 3. For each edge marked as skipped, check its target node: + - If ALL incoming edges are skipped, mark the node as skipped + - Otherwise, leave the node state unchanged + + :param nodes: mapping of node ID to node instance + :param edges: mapping of edge ID to edge instance + :param in_edges: mapping of node ID to incoming edge IDs + :param out_edges: mapping of node ID to outgoing edge IDs + :param active_root_id: ID of the active root node + """ + # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) + top_level_roots: list[str] = [ + node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT + ] + + # If there's only one root or the active root is not a top-level root, no marking needed + if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: + return + + # Mark inactive root nodes as skipped + inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] + for root_id in inactive_roots: + if root_id in nodes: + nodes[root_id].state = NodeState.SKIPPED + + # Recursively mark downstream nodes and edges + def mark_downstream(node_id: str) -> None: + """Recursively mark downstream nodes and edges as skipped.""" + if nodes[node_id].state != NodeState.SKIPPED: + return + # If this node is skipped, mark all its outgoing edges as skipped + out_edge_ids = out_edges.get(node_id, []) + for edge_id in out_edge_ids: + edge = edges[edge_id] + edge.state = NodeState.SKIPPED + + # Check the target node of this edge + target_node = nodes[edge.head] + in_edge_ids = in_edges.get(target_node.id, []) + in_edge_states = [edges[eid].state for eid in in_edge_ids] + + # If all incoming edges are skipped, mark the node as skipped + if all(state == NodeState.SKIPPED for state in in_edge_states): + target_node.state = NodeState.SKIPPED + # Recursively process downstream nodes + mark_downstream(target_node.id) + + # Process each inactive root and its downstream nodes + for root_id in inactive_roots: + mark_downstream(root_id) + @classmethod def init( cls, *, graph_config: Mapping[str, Any], node_factory: "NodeFactory", - root_node_id: Optional[str] = None, + root_node_id: str | None = None, ) -> "Graph": """ Initialize graph @@ -227,6 +293,9 @@ class Graph: # Get root node instance root_node = nodes[root_node_id] + # Mark inactive root branches as skipped + cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) + # Create and return the graph return cls( nodes=nodes, diff --git a/api/core/workflow/graph_engine/command_channels/in_memory_channel.py b/api/core/workflow/graph_engine/command_channels/in_memory_channel.py index ef498e6890..bdaf236796 100644 --- a/api/core/workflow/graph_engine/command_channels/in_memory_channel.py +++ b/api/core/workflow/graph_engine/command_channels/in_memory_channel.py @@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi """ from queue import Queue +from typing import final from ..entities.commands import GraphEngineCommand +@final class InMemoryChannel: """ In-memory command channel implementation using a thread-safe queue. diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 6feb8b8a25..7809e43e32 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue. """ import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand @@ -15,6 +15,7 @@ if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper +@final class RedisChannel: """ Redis-based command channel implementation for distributed systems. @@ -86,7 +87,7 @@ class RedisChannel: pipe.expire(self._key, self._command_ttl) pipe.execute() - def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]: + def _deserialize_command(self, data: dict) -> GraphEngineCommand | None: """ Deserialize a command from dictionary data. diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index f8bae5e21a..9f8d20b1b9 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -3,6 +3,7 @@ Command handler implementations. """ import logging +from typing import final from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand @@ -11,6 +12,7 @@ from .command_processor import CommandHandler logger = logging.getLogger(__name__) +@final class AbortCommandHandler(CommandHandler): """Handles abort commands.""" diff --git a/api/core/workflow/graph_engine/command_processing/command_processor.py b/api/core/workflow/graph_engine/command_processing/command_processor.py index 06b3a8d8a4..2521058ef2 100644 --- a/api/core/workflow/graph_engine/command_processing/command_processor.py +++ b/api/core/workflow/graph_engine/command_processing/command_processor.py @@ -3,7 +3,7 @@ Main command processor for handling external commands. """ import logging -from typing import Protocol +from typing import Protocol, final from ..domain.graph_execution import GraphExecution from ..entities.commands import GraphEngineCommand @@ -18,6 +18,7 @@ class CommandHandler(Protocol): def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... +@final class CommandProcessor: """ Processes external commands sent to the engine. diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index b8fa801289..c375b08fe0 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state. """ from dataclasses import dataclass, field -from typing import Optional from .node_execution import NodeExecution @@ -21,7 +20,7 @@ class GraphExecution: started: bool = False completed: bool = False aborted: bool = False - error: Optional[Exception] = None + error: Exception | None = None node_executions: dict[str, NodeExecution] = field(default_factory=dict) def start(self) -> None: diff --git a/api/core/workflow/graph_engine/domain/node_execution.py b/api/core/workflow/graph_engine/domain/node_execution.py index 937ae0fb93..85700caa3a 100644 --- a/api/core/workflow/graph_engine/domain/node_execution.py +++ b/api/core/workflow/graph_engine/domain/node_execution.py @@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state. """ from dataclasses import dataclass -from typing import Optional from core.workflow.enums import NodeState @@ -20,8 +19,8 @@ class NodeExecution: node_id: str state: NodeState = NodeState.UNKNOWN retry_count: int = 0 - execution_id: Optional[str] = None - error: Optional[str] = None + execution_id: str | None = None + error: str | None = None def mark_started(self, execution_id: str) -> None: """Mark the node as started with an execution ID.""" diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index a92ebf512d..7e25fc0866 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -6,7 +6,7 @@ instance to control its execution flow. """ from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel): """Base class for all GraphEngine commands.""" command_type: CommandType = Field(..., description="Type of command") - payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload") + payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") class AbortCommand(GraphEngineCommand): """Command to abort a running workflow execution.""" command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: Optional[str] = Field(default=None, description="Optional reason for abort") + reason: str | None = Field(default=None, description="Optional reason for abort") diff --git a/api/core/workflow/graph_engine/error_handling/__init__.py b/api/core/workflow/graph_engine/error_handling/__init__.py index 4c865e58fc..1316710d0d 100644 --- a/api/core/workflow/graph_engine/error_handling/__init__.py +++ b/api/core/workflow/graph_engine/error_handling/__init__.py @@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns. from .abort_strategy import AbortStrategy from .default_value_strategy import DefaultValueStrategy from .error_handler import ErrorHandler -from .error_strategy import ErrorStrategy from .fail_branch_strategy import FailBranchStrategy from .retry_strategy import RetryStrategy @@ -16,7 +15,6 @@ __all__ = [ "AbortStrategy", "DefaultValueStrategy", "ErrorHandler", - "ErrorStrategy", "FailBranchStrategy", "RetryStrategy", ] diff --git a/api/core/workflow/graph_engine/error_handling/abort_strategy.py b/api/core/workflow/graph_engine/error_handling/abort_strategy.py index e747704fda..6a805bd124 100644 --- a/api/core/workflow/graph_engine/error_handling/abort_strategy.py +++ b/api/core/workflow/graph_engine/error_handling/abort_strategy.py @@ -3,7 +3,7 @@ Abort error strategy implementation. """ import logging -from typing import Optional +from typing import final from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent @@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent logger = logging.getLogger(__name__) +@final class AbortStrategy: """ Error strategy that aborts execution on failure. @@ -19,7 +20,7 @@ class AbortStrategy: It stops the entire graph execution when a node fails. """ - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: + def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: """ Handle error by aborting execution. diff --git a/api/core/workflow/graph_engine/error_handling/default_value_strategy.py b/api/core/workflow/graph_engine/error_handling/default_value_strategy.py index 92e61dc22a..61d36399aa 100644 --- a/api/core/workflow/graph_engine/error_handling/default_value_strategy.py +++ b/api/core/workflow/graph_engine/error_handling/default_value_strategy.py @@ -2,7 +2,7 @@ Default value error strategy implementation. """ -from typing import Optional +from typing import final from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph import Graph @@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent from core.workflow.node_events import NodeRunResult +@final class DefaultValueStrategy: """ Error strategy that uses default values on failure. @@ -18,7 +19,7 @@ class DefaultValueStrategy: predefined default output values. """ - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: + def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: """ Handle error by using default values. diff --git a/api/core/workflow/graph_engine/error_handling/error_handler.py b/api/core/workflow/graph_engine/error_handling/error_handler.py index 7f6abb146c..b51d7e4dad 100644 --- a/api/core/workflow/graph_engine/error_handling/error_handler.py +++ b/api/core/workflow/graph_engine/error_handling/error_handler.py @@ -2,7 +2,7 @@ Main error handler that coordinates error strategies. """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum from core.workflow.graph import Graph @@ -17,6 +17,7 @@ if TYPE_CHECKING: from ..domain import GraphExecution +@final class ErrorHandler: """ Coordinates error handling strategies for node failures. @@ -34,16 +35,16 @@ class ErrorHandler: graph: The workflow graph graph_execution: The graph execution state """ - self.graph = graph - self.graph_execution = graph_execution + self._graph = graph + self._graph_execution = graph_execution # Initialize strategies - self.abort_strategy = AbortStrategy() - self.retry_strategy = RetryStrategy() - self.fail_branch_strategy = FailBranchStrategy() - self.default_value_strategy = DefaultValueStrategy() + self._abort_strategy = AbortStrategy() + self._retry_strategy = RetryStrategy() + self._fail_branch_strategy = FailBranchStrategy() + self._default_value_strategy = DefaultValueStrategy() - def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]: + def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: """ Handle a node failure event. @@ -56,14 +57,14 @@ class ErrorHandler: Returns: Optional new event to process, or None to abort """ - node = self.graph.nodes[event.node_id] + node = self._graph.nodes[event.node_id] # Get retry count from NodeExecution - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) retry_count = node_execution.retry_count # First check if retry is configured and not exhausted if node.retry and retry_count < node.retry_config.max_retries: - result = self.retry_strategy.handle_error(event, self.graph, retry_count) + result = self._retry_strategy.handle_error(event, self._graph, retry_count) if result: # Retry count will be incremented when NodeRunRetryEvent is handled return result @@ -71,12 +72,10 @@ class ErrorHandler: # Apply configured error strategy strategy = node.error_strategy - if strategy is None: - return self.abort_strategy.handle_error(event, self.graph, retry_count) - elif strategy == ErrorStrategyEnum.FAIL_BRANCH: - return self.fail_branch_strategy.handle_error(event, self.graph, retry_count) - elif strategy == ErrorStrategyEnum.DEFAULT_VALUE: - return self.default_value_strategy.handle_error(event, self.graph, retry_count) - else: - # Unknown strategy, default to abort - return self.abort_strategy.handle_error(event, self.graph, retry_count) + match strategy: + case None: + return self._abort_strategy.handle_error(event, self._graph, retry_count) + case ErrorStrategyEnum.FAIL_BRANCH: + return self._fail_branch_strategy.handle_error(event, self._graph, retry_count) + case ErrorStrategyEnum.DEFAULT_VALUE: + return self._default_value_strategy.handle_error(event, self._graph, retry_count) diff --git a/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py b/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py index 82e434c89b..437c2bc7da 100644 --- a/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py +++ b/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py @@ -2,7 +2,7 @@ Fail branch error strategy implementation. """ -from typing import Optional +from typing import final from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph import Graph @@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent from core.workflow.node_events import NodeRunResult +@final class FailBranchStrategy: """ Error strategy that continues execution via a fail branch. @@ -18,7 +19,7 @@ class FailBranchStrategy: through a designated fail-branch edge. """ - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: + def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: """ Handle error by taking the fail branch. diff --git a/api/core/workflow/graph_engine/error_handling/retry_strategy.py b/api/core/workflow/graph_engine/error_handling/retry_strategy.py index 5956a7c62e..e4010b6bdb 100644 --- a/api/core/workflow/graph_engine/error_handling/retry_strategy.py +++ b/api/core/workflow/graph_engine/error_handling/retry_strategy.py @@ -3,12 +3,13 @@ Retry error strategy implementation. """ import time -from typing import Optional +from typing import final from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent +@final class RetryStrategy: """ Error strategy that retries failed nodes. @@ -17,7 +18,7 @@ class RetryStrategy: maximum number of retries with configurable intervals. """ - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: + def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: """ Handle error by retrying the node. diff --git a/api/core/workflow/graph_engine/event_management/event_collector.py b/api/core/workflow/graph_engine/event_management/event_collector.py index 3d266fc012..a41dcf5b10 100644 --- a/api/core/workflow/graph_engine/event_management/event_collector.py +++ b/api/core/workflow/graph_engine/event_management/event_collector.py @@ -3,12 +3,92 @@ Event collector for buffering and managing events. """ import threading +from typing import final from core.workflow.graph_events import GraphEngineEvent from ..layers.base import Layer +@final +class ReadWriteLock: + """ + A read-write lock implementation that allows multiple concurrent readers + but only one writer at a time. + """ + + def __init__(self) -> None: + self._read_ready = threading.Condition(threading.RLock()) + self._readers = 0 + + def acquire_read(self) -> None: + """Acquire a read lock.""" + self._read_ready.acquire() + try: + self._readers += 1 + finally: + self._read_ready.release() + + def release_read(self) -> None: + """Release a read lock.""" + self._read_ready.acquire() + try: + self._readers -= 1 + if self._readers == 0: + self._read_ready.notify_all() + finally: + self._read_ready.release() + + def acquire_write(self) -> None: + """Acquire a write lock.""" + self._read_ready.acquire() + while self._readers > 0: + self._read_ready.wait() + + def release_write(self) -> None: + """Release a write lock.""" + self._read_ready.release() + + def read_lock(self) -> "ReadLockContext": + """Return a context manager for read locking.""" + return ReadLockContext(self) + + def write_lock(self) -> "WriteLockContext": + """Return a context manager for write locking.""" + return WriteLockContext(self) + + +@final +class ReadLockContext: + """Context manager for read locks.""" + + def __init__(self, lock: ReadWriteLock) -> None: + self._lock = lock + + def __enter__(self) -> "ReadLockContext": + self._lock.acquire_read() + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: + self._lock.release_read() + + +@final +class WriteLockContext: + """Context manager for write locks.""" + + def __init__(self, lock: ReadWriteLock) -> None: + self._lock = lock + + def __enter__(self) -> "WriteLockContext": + self._lock.acquire_write() + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None: + self._lock.release_write() + + +@final class EventCollector: """ Collects and buffers events for later retrieval. @@ -20,7 +100,7 @@ class EventCollector: def __init__(self) -> None: """Initialize the event collector.""" self._events: list[GraphEngineEvent] = [] - self._lock = threading.Lock() + self._lock = ReadWriteLock() self._layers: list[Layer] = [] def set_layers(self, layers: list[Layer]) -> None: @@ -39,7 +119,7 @@ class EventCollector: Args: event: The event to collect """ - with self._lock: + with self._lock.write_lock(): self._events.append(event) self._notify_layers(event) @@ -50,7 +130,7 @@ class EventCollector: Returns: List of collected events """ - with self._lock: + with self._lock.read_lock(): return list(self._events) def get_new_events(self, start_index: int) -> list[GraphEngineEvent]: @@ -63,7 +143,7 @@ class EventCollector: Returns: List of new events """ - with self._lock: + with self._lock.read_lock(): return list(self._events[start_index:]) def event_count(self) -> int: @@ -73,12 +153,12 @@ class EventCollector: Returns: Number of collected events """ - with self._lock: + with self._lock.read_lock(): return len(self._events) def clear(self) -> None: """Clear all collected events.""" - with self._lock: + with self._lock.write_lock(): self._events.clear() def _notify_layers(self, event: GraphEngineEvent) -> None: diff --git a/api/core/workflow/graph_engine/event_management/event_emitter.py b/api/core/workflow/graph_engine/event_management/event_emitter.py index 36f9d5d5a2..6fb0b96e8c 100644 --- a/api/core/workflow/graph_engine/event_management/event_emitter.py +++ b/api/core/workflow/graph_engine/event_management/event_emitter.py @@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers. import threading import time from collections.abc import Generator +from typing import final from core.workflow.graph_events import GraphEngineEvent from .event_collector import EventCollector +@final class EventEmitter: """ Emits collected events as a generator for external consumption. diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index db3137e99a..842bd2635f 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,7 +3,7 @@ Event handler implementations for different event types. """ import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType @@ -38,6 +38,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@final class EventHandlerRegistry: """ Registry of event handlers for different event types. @@ -52,12 +53,12 @@ class EventHandlerRegistry: graph_runtime_state: GraphRuntimeState, graph_execution: GraphExecution, response_coordinator: ResponseStreamCoordinator, - event_collector: Optional["EventCollector"] = None, - branch_handler: Optional["BranchHandler"] = None, - edge_processor: Optional["EdgeProcessor"] = None, - node_state_manager: Optional["NodeStateManager"] = None, - execution_tracker: Optional["ExecutionTracker"] = None, - error_handler: Optional["ErrorHandler"] = None, + event_collector: "EventCollector", + branch_handler: "BranchHandler", + edge_processor: "EdgeProcessor", + node_state_manager: "NodeStateManager", + execution_tracker: "ExecutionTracker", + error_handler: "ErrorHandler", ) -> None: """ Initialize the event handler registry. @@ -67,23 +68,23 @@ class EventHandlerRegistry: graph_runtime_state: Runtime state with variable pool graph_execution: Graph execution aggregate response_coordinator: Response stream coordinator - event_collector: Optional event collector for collecting events - branch_handler: Optional branch handler for branch node processing - edge_processor: Optional edge processor for edge traversal - node_state_manager: Optional node state manager - execution_tracker: Optional execution tracker - error_handler: Optional error handler + event_collector: Event collector for collecting events + branch_handler: Branch handler for branch node processing + edge_processor: Edge processor for edge traversal + node_state_manager: Node state manager + execution_tracker: Execution tracker + error_handler: Error handler """ - self.graph = graph - self.graph_runtime_state = graph_runtime_state - self.graph_execution = graph_execution - self.response_coordinator = response_coordinator - self.event_collector = event_collector - self.branch_handler = branch_handler - self.edge_processor = edge_processor - self.node_state_manager = node_state_manager - self.execution_tracker = execution_tracker - self.error_handler = error_handler + self._graph = graph + self._graph_runtime_state = graph_runtime_state + self._graph_execution = graph_execution + self._response_coordinator = response_coordinator + self._event_collector = event_collector + self._branch_handler = branch_handler + self._edge_processor = edge_processor + self._node_state_manager = node_state_manager + self._execution_tracker = execution_tracker + self._error_handler = error_handler def handle_event(self, event: GraphNodeEventBase) -> None: """ @@ -93,9 +94,8 @@ class EventHandlerRegistry: event: The event to handle """ # Events in loops or iterations are always collected - if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id): - if self.event_collector: - self.event_collector.collect(event) + if event.in_loop_id or event.in_iteration_id: + self._event_collector.collect(event) return # Handle specific event types @@ -125,12 +125,10 @@ class EventHandlerRegistry: ), ): # Iteration and loop events are collected directly - if self.event_collector: - self.event_collector.collect(event) + self._event_collector.collect(event) else: # Collect unhandled events - if self.event_collector: - self.event_collector.collect(event) + self._event_collector.collect(event) logger.warning("Unhandled event type: %s", type(event).__name__) def _handle_node_started(self, event: NodeRunStartedEvent) -> None: @@ -141,15 +139,14 @@ class EventHandlerRegistry: event: The node started event """ # Track execution in domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_started(event.id) # Track in response coordinator for stream ordering - self.response_coordinator.track_node_execution(event.node_id, event.id) + self._response_coordinator.track_node_execution(event.node_id, event.id) # Collect the event - if self.event_collector: - self.event_collector.collect(event) + self._event_collector.collect(event) def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: """ @@ -159,12 +156,11 @@ class EventHandlerRegistry: event: The stream chunk event """ # Process with response coordinator - streaming_events = list(self.response_coordinator.intercept_event(event)) + streaming_events = list(self._response_coordinator.intercept_event(event)) # Collect all events - if self.event_collector: - for stream_event in streaming_events: - self.event_collector.collect(stream_event) + for stream_event in streaming_events: + self._event_collector.collect(stream_event) def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: """ @@ -177,55 +173,44 @@ class EventHandlerRegistry: event: The node succeeded event """ # Update domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() # Store outputs in variable pool self._store_node_outputs(event) # Forward to response coordinator and emit streaming events - streaming_events = list(self.response_coordinator.intercept_event(event)) - if self.event_collector: - for stream_event in streaming_events: - self.event_collector.collect(stream_event) + streaming_events = self._response_coordinator.intercept_event(event) + for stream_event in streaming_events: + self._event_collector.collect(stream_event) # Process edges and get ready nodes - node = self.graph.nodes[event.node_id] + node = self._graph.nodes[event.node_id] if node.execution_type == NodeExecutionType.BRANCH: - if self.branch_handler: - ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = [], [] + ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion( + event.node_id, event.node_run_result.edge_source_handle + ) else: - if self.edge_processor: - ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id) - else: - ready_nodes, edge_streaming_events = [], [] + ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) # Collect streaming events from edge processing - if self.event_collector: - for edge_event in edge_streaming_events: - self.event_collector.collect(edge_event) + for edge_event in edge_streaming_events: + self._event_collector.collect(edge_event) # Enqueue ready nodes - if self.node_state_manager and self.execution_tracker: - for node_id in ready_nodes: - self.node_state_manager.enqueue_node(node_id) - self.execution_tracker.add(node_id) + for node_id in ready_nodes: + self._node_state_manager.enqueue_node(node_id) + self._execution_tracker.add(node_id) # Update execution tracking - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + self._execution_tracker.remove(event.node_id) # Handle response node outputs if node.execution_type == NodeExecutionType.RESPONSE: self._update_response_outputs(event) # Collect the event - if self.event_collector: - self.event_collector.collect(event) + self._event_collector.collect(event) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: """ @@ -235,29 +220,19 @@ class EventHandlerRegistry: event: The node failed event """ # Update domain model - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_failed(event.error) - if self.error_handler: - result = self.error_handler.handle_node_failure(event) + result = self._error_handler.handle_node_failure(event) - if result: - # Process the resulting event (retry, exception, etc.) - self.handle_event(result) - else: - # Abort execution - self.graph_execution.fail(RuntimeError(event.error)) - if self.event_collector: - self.event_collector.collect(event) - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + if result: + # Process the resulting event (retry, exception, etc.) + self.handle_event(result) else: - # Without error handler, just fail - self.graph_execution.fail(RuntimeError(event.error)) - if self.event_collector: - self.event_collector.collect(event) - if self.execution_tracker: - self.execution_tracker.remove(event.node_id) + # Abort execution + self._graph_execution.fail(RuntimeError(event.error)) + self._event_collector.collect(event) + self._execution_tracker.remove(event.node_id) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: """ @@ -267,7 +242,7 @@ class EventHandlerRegistry: event: The node exception event """ # Node continues via fail-branch, so it's technically "succeeded" - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: @@ -277,7 +252,7 @@ class EventHandlerRegistry: Args: event: The node retry event """ - node_execution = self.graph_execution.get_or_create_node_execution(event.node_id) + node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.increment_retry() def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None: @@ -288,16 +263,16 @@ class EventHandlerRegistry: event: The node succeeded event containing outputs """ for variable_name, variable_value in event.node_run_result.outputs.items(): - self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) + self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value) def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None: """Update response outputs for response nodes.""" for key, value in event.node_run_result.outputs.items(): if key == "answer": - existing = self.graph_runtime_state.outputs.get("answer", "") + existing = self._graph_runtime_state.outputs.get("answer", "") if existing: - self.graph_runtime_state.outputs["answer"] = f"{existing}{value}" + self._graph_runtime_state.outputs["answer"] = f"{existing}{value}" else: - self.graph_runtime_state.outputs["answer"] = value + self._graph_runtime_state.outputs["answer"] = value else: - self.graph_runtime_state.outputs[key] = value + self._graph_runtime_state.outputs[key] = value diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dcea94b994..dd98536fba 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -9,7 +9,7 @@ import contextvars import logging import queue from collections.abc import Generator, Mapping -from typing import Any, Optional +from typing import final from flask import Flask, current_app @@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph_events import ( GraphEngineEvent, + GraphNodeEventBase, GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunStartedEvent, @@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo logger = logging.getLogger(__name__) +@final class GraphEngine: """ Queue-based graph execution engine. @@ -62,7 +64,7 @@ class GraphEngine: invoke_from: InvokeFrom, call_depth: int, graph: Graph, - graph_config: Mapping[str, Any], + graph_config: Mapping[str, object], graph_runtime_state: GraphRuntimeState, max_execution_steps: int, max_execution_time: int, @@ -103,7 +105,7 @@ class GraphEngine: # Initialize queues self.ready_queue: queue.Queue[str] = queue.Queue() - self.event_queue: queue.Queue = queue.Queue() + self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() # Initialize subsystems self._initialize_subsystems() @@ -185,7 +187,7 @@ class GraphEngine: event_handler=self.event_handler_registry, event_collector=self.event_collector, command_processor=self.command_processor, - worker_pool=self.worker_pool, + worker_pool=self._worker_pool, ) self.dispatcher = Dispatcher( @@ -209,7 +211,7 @@ class GraphEngine: def _setup_worker_management(self) -> None: """Initialize worker management subsystem.""" # Capture context for workers - flask_app: Optional[Flask] = None + flask_app: Flask | None = None try: flask_app = current_app._get_current_object() # type: ignore except RuntimeError: @@ -218,8 +220,8 @@ class GraphEngine: context_vars = contextvars.copy_context() # Create worker management components - self.activity_tracker = ActivityTracker() - self.dynamic_scaler = DynamicScaler( + self._activity_tracker = ActivityTracker() + self._dynamic_scaler = DynamicScaler( min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS), max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS), scale_up_threshold=( @@ -233,15 +235,15 @@ class GraphEngine: else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME ), ) - self.worker_factory = WorkerFactory(flask_app, context_vars) + self._worker_factory = WorkerFactory(flask_app, context_vars) - self.worker_pool = WorkerPool( + self._worker_pool = WorkerPool( ready_queue=self.ready_queue, event_queue=self.event_queue, graph=self.graph, - worker_factory=self.worker_factory, - dynamic_scaler=self.dynamic_scaler, - activity_tracker=self.activity_tracker, + worker_factory=self._worker_factory, + dynamic_scaler=self._dynamic_scaler, + activity_tracker=self._activity_tracker, ) def _validate_graph_state_consistency(self) -> None: @@ -319,10 +321,10 @@ class GraphEngine: def _start_execution(self) -> None: """Start execution subsystems.""" # Calculate initial worker count - initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph) + initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph) # Start worker pool - self.worker_pool.start(initial_workers) + self._worker_pool.start(initial_workers) # Register response nodes for node in self.graph.nodes.values(): @@ -340,7 +342,7 @@ class GraphEngine: def _stop_execution(self) -> None: """Stop execution subsystems.""" self.dispatcher.stop() - self.worker_pool.stop() + self._worker_pool.stop() # Don't mark complete here as the dispatcher already does it # Notify layers diff --git a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py index 685867a02d..b371f3bc73 100644 --- a/api/core/workflow/graph_engine/graph_traversal/branch_handler.py +++ b/api/core/workflow/graph_engine/graph_traversal/branch_handler.py @@ -2,15 +2,18 @@ Branch node handling for graph traversal. """ -from typing import Optional +from collections.abc import Sequence +from typing import final from core.workflow.graph import Graph +from core.workflow.graph_events.node import NodeRunStreamChunkEvent from ..state_management import EdgeStateManager from .edge_processor import EdgeProcessor from .skip_propagator import SkipPropagator +@final class BranchHandler: """ Handles branch node logic during graph traversal. @@ -40,7 +43,9 @@ class BranchHandler: self.skip_propagator = skip_propagator self.edge_state_manager = edge_state_manager - def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]: + def handle_branch_completion( + self, node_id: str, selected_handle: str | None + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Handle completion of a branch node. @@ -58,10 +63,10 @@ class BranchHandler: raise ValueError(f"Branch node {node_id} completed without selecting a branch") # Categorize edges into selected and unselected - selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) + _, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) # Skip all unselected paths - self.skip_propagator.skip_branch_paths(node_id, unselected_edges) + self.skip_propagator.skip_branch_paths(unselected_edges) # Process selected edges and get ready nodes and streaming events return self.edge_processor.process_node_success(node_id, selected_handle) diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py index 79a7952282..ac2c658b4b 100644 --- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py +++ b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py @@ -2,13 +2,18 @@ Edge processing logic for graph traversal. """ +from collections.abc import Sequence +from typing import final + from core.workflow.enums import NodeExecutionType from core.workflow.graph import Edge, Graph +from core.workflow.graph_events import NodeRunStreamChunkEvent from ..response_coordinator import ResponseStreamCoordinator from ..state_management import EdgeStateManager, NodeStateManager +@final class EdgeProcessor: """ Processes edges during graph execution. @@ -38,7 +43,9 @@ class EdgeProcessor: self.node_state_manager = node_state_manager self.response_coordinator = response_coordinator - def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]: + def process_node_success( + self, node_id: str, selected_handle: str | None = None + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges after a node succeeds. @@ -56,7 +63,7 @@ class EdgeProcessor: else: return self._process_non_branch_node_edges(node_id) - def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]: + def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges for non-branch nodes (mark all as TAKEN). @@ -66,8 +73,8 @@ class EdgeProcessor: Returns: Tuple of (list of downstream nodes ready for execution, list of streaming events) """ - ready_nodes = [] - all_streaming_events = [] + ready_nodes: list[str] = [] + all_streaming_events: list[NodeRunStreamChunkEvent] = [] outgoing_edges = self.graph.get_outgoing_edges(node_id) for edge in outgoing_edges: @@ -77,7 +84,9 @@ class EdgeProcessor: return ready_nodes, all_streaming_events - def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]: + def _process_branch_node_edges( + self, node_id: str, selected_handle: str | None + ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Process edges for branch nodes. @@ -94,8 +103,8 @@ class EdgeProcessor: if not selected_handle: raise ValueError(f"Branch node {node_id} did not select any edge") - ready_nodes = [] - all_streaming_events = [] + ready_nodes: list[str] = [] + all_streaming_events: list[NodeRunStreamChunkEvent] = [] # Categorize edges selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle) @@ -112,7 +121,7 @@ class EdgeProcessor: return ready_nodes, all_streaming_events - def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]: + def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: """ Mark edge as taken and check downstream node. @@ -129,11 +138,11 @@ class EdgeProcessor: streaming_events = self.response_coordinator.on_edge_taken(edge.id) # Check if downstream node is ready - ready_nodes = [] + ready_nodes: list[str] = [] if self.node_state_manager.is_node_ready(edge.head): ready_nodes.append(edge.head) - return ready_nodes, list(streaming_events) + return ready_nodes, streaming_events def _process_skipped_edge(self, edge: Edge) -> None: """ diff --git a/api/core/workflow/graph_engine/graph_traversal/node_readiness.py b/api/core/workflow/graph_engine/graph_traversal/node_readiness.py index 93f9935a90..59bce3942c 100644 --- a/api/core/workflow/graph_engine/graph_traversal/node_readiness.py +++ b/api/core/workflow/graph_engine/graph_traversal/node_readiness.py @@ -2,10 +2,13 @@ Node readiness checking for execution. """ +from typing import final + from core.workflow.enums import NodeState from core.workflow.graph import Graph +@final class NodeReadinessChecker: """ Checks if nodes are ready for execution based on their dependencies. @@ -71,7 +74,7 @@ class NodeReadinessChecker: Returns: List of node IDs that are now ready """ - ready_nodes = [] + ready_nodes: list[str] = [] outgoing_edges = self.graph.get_outgoing_edges(from_node_id) for edge in outgoing_edges: diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py index ef0e5e3273..5ac445d405 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py @@ -2,11 +2,15 @@ Skip state propagation through the graph. """ -from core.workflow.graph import Graph +from collections.abc import Sequence +from typing import final + +from core.workflow.graph import Edge, Graph from ..state_management import EdgeStateManager, NodeStateManager +@final class SkipPropagator: """ Propagates skip states through the graph. @@ -57,9 +61,8 @@ class SkipPropagator: # If any edge is taken, node may still execute if edge_states["has_taken"]: - # Check if node is ready and enqueue if so - if self.node_state_manager.is_node_ready(downstream_node_id): - self.node_state_manager.enqueue_node(downstream_node_id) + # Enqueue node + self.node_state_manager.enqueue_node(downstream_node_id) return # All edges are skipped, propagate skip to this node @@ -83,12 +86,11 @@ class SkipPropagator: # Recursively propagate skip self.propagate_skip_from_edge(edge.id) - def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None: + def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: """ Skip all paths from unselected branch edges. Args: - node_id: The ID of the branch node unselected_edges: List of edges not taken by the branch """ for edge in unselected_edges: diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/core/workflow/graph_engine/layers/base.py index df8115c526..febdc3de6d 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/core/workflow/graph_engine/layers/base.py @@ -6,7 +6,6 @@ intercept and respond to GraphEngine events. """ from abc import ABC, abstractmethod -from typing import Optional from core.workflow.entities import GraphRuntimeState from core.workflow.graph_engine.protocols.command_channel import CommandChannel @@ -28,8 +27,8 @@ class Layer(ABC): def __init__(self) -> None: """Initialize the layer. Subclasses can override with custom parameters.""" - self.graph_runtime_state: Optional[GraphRuntimeState] = None - self.command_channel: Optional[CommandChannel] = None + self.graph_runtime_state: GraphRuntimeState | None = None + self.command_channel: CommandChannel | None = None def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None: """ @@ -73,7 +72,7 @@ class Layer(ABC): pass @abstractmethod - def on_graph_end(self, error: Optional[Exception]) -> None: + def on_graph_end(self, error: Exception | None) -> None: """ Called when graph execution ends. diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/core/workflow/graph_engine/layers/debug_logging.py index b5222c51d3..3052600161 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/core/workflow/graph_engine/layers/debug_logging.py @@ -7,7 +7,7 @@ graph execution for debugging purposes. import logging from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, final from core.workflow.graph_events import ( GraphEngineEvent, @@ -34,6 +34,7 @@ from core.workflow.graph_events import ( from .base import Layer +@final class DebugLoggingLayer(Layer): """ A layer that provides comprehensive logging of GraphEngine execution. @@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer): # Log unknown events at debug level self.logger.debug("Event: %s", event_class) - def on_graph_end(self, error: Optional[Exception]) -> None: + def on_graph_end(self, error: Exception | None) -> None: """Log graph execution end with summary statistics.""" self.logger.info("=" * 80) diff --git a/api/core/workflow/graph_engine/layers/execution_limits.py b/api/core/workflow/graph_engine/layers/execution_limits.py index 321a7df8c3..efda0bacbe 100644 --- a/api/core/workflow/graph_engine/layers/execution_limits.py +++ b/api/core/workflow/graph_engine/layers/execution_limits.py @@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution. import logging import time from enum import Enum -from typing import Optional +from typing import final from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType from core.workflow.graph_engine.layers import Layer @@ -29,6 +29,7 @@ class LimitType(Enum): TIME_LIMIT = "time_limit" +@final class ExecutionLimitsLayer(Layer): """ Layer that enforces execution limits for workflows. @@ -53,7 +54,7 @@ class ExecutionLimitsLayer(Layer): self.max_time = max_time # Runtime tracking - self.start_time: Optional[float] = None + self.start_time: float | None = None self.step_count = 0 self.logger = logging.getLogger(__name__) @@ -94,7 +95,7 @@ class ExecutionLimitsLayer(Layer): if self._reached_time_limitation(): self._send_abort_command(LimitType.TIME_LIMIT) - def on_graph_end(self, error: Optional[Exception]) -> None: + def on_graph_end(self, error: Exception | None) -> None: """Called when graph execution ends.""" if self._execution_started and not self._execution_ended: self._execution_ended = True diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index a4f9cc7192..ed62209acb 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks. Supports stop, pause, and resume operations. """ -from typing import Optional +from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.entities.commands import AbortCommand from extensions.ext_redis import redis_client +@final class GraphEngineManager: """ Manager for sending control commands to GraphEngine instances. @@ -23,7 +24,7 @@ class GraphEngineManager: """ @staticmethod - def send_stop_command(task_id: str, reason: Optional[str] = None) -> None: + def send_stop_command(task_id: str, reason: str | None = None) -> None: """ Send a stop command to a running workflow. diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 7fc441f194..694355298c 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -6,7 +6,9 @@ import logging import queue import threading import time -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final + +from core.workflow.graph_events.base import GraphNodeEventBase from ..event_management import EventCollector, EventEmitter from .execution_coordinator import ExecutionCoordinator @@ -17,6 +19,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@final class Dispatcher: """ Main dispatcher that processes events from the event queue. @@ -27,12 +30,12 @@ class Dispatcher: def __init__( self, - event_queue: queue.Queue, + event_queue: queue.Queue[GraphNodeEventBase], event_handler: "EventHandlerRegistry", event_collector: EventCollector, execution_coordinator: ExecutionCoordinator, max_execution_time: int, - event_emitter: Optional[EventEmitter] = None, + event_emitter: EventEmitter | None = None, ) -> None: """ Initialize the dispatcher. @@ -52,9 +55,9 @@ class Dispatcher: self.max_execution_time = max_execution_time self.event_emitter = event_emitter - self._thread: Optional[threading.Thread] = None + self._thread: threading.Thread | None = None self._stop_event = threading.Event() - self._start_time: Optional[float] = None + self._start_time: float | None = None def start(self) -> None: """Start the dispatcher thread.""" diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index 899cb6a0d5..5f95b5b29e 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -2,7 +2,7 @@ Execution coordinator for managing overall workflow execution. """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, final from ..command_processing import CommandProcessor from ..domain import GraphExecution @@ -14,6 +14,7 @@ if TYPE_CHECKING: from ..event_management import EventHandlerRegistry +@final class ExecutionCoordinator: """ Coordinates overall execution flow between subsystems. diff --git a/api/core/workflow/graph_engine/output_registry/registry.py b/api/core/workflow/graph_engine/output_registry/registry.py index 0f3e690eb1..4df7da207c 100644 --- a/api/core/workflow/graph_engine/output_registry/registry.py +++ b/api/core/workflow/graph_engine/output_registry/registry.py @@ -7,7 +7,7 @@ thread-safe storage for node outputs. from collections.abc import Sequence from threading import RLock -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Union, final from core.variables import Segment from core.workflow.entities.variable_pool import VariablePool @@ -18,6 +18,7 @@ if TYPE_CHECKING: from core.workflow.graph_events import NodeRunStreamChunkEvent +@final class OutputRegistry: """ Thread-safe registry for storing and retrieving node outputs. @@ -47,7 +48,7 @@ class OutputRegistry: with self._lock: self._scalars.add(selector, value) - def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]: + def get_scalar(self, selector: Sequence[str]) -> "Segment | None": """ Get a scalar value for the given selector. @@ -81,7 +82,7 @@ class OutputRegistry: except ValueError: raise ValueError(f"Stream {'.'.join(selector)} is already closed") - def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]: + def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None": """ Pop the next unread NodeRunStreamChunkEvent from the stream. diff --git a/api/core/workflow/graph_engine/output_registry/stream.py b/api/core/workflow/graph_engine/output_registry/stream.py index dc12e479a4..8a99b56d1f 100644 --- a/api/core/workflow/graph_engine/output_registry/stream.py +++ b/api/core/workflow/graph_engine/output_registry/stream.py @@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry to manage streaming data chunks. """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, final if TYPE_CHECKING: from core.workflow.graph_events import NodeRunStreamChunkEvent +@final class Stream: """ A stream that holds NodeRunStreamChunkEvent objects and tracks read position. @@ -41,7 +42,7 @@ class Stream: raise ValueError("Cannot append to a closed stream") self.events.append(event) - def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]: + def pop_next(self) -> "NodeRunStreamChunkEvent | None": """ Pop the next unread NodeRunStreamChunkEvent from the stream. diff --git a/api/core/workflow/graph_engine/error_handling/error_strategy.py b/api/core/workflow/graph_engine/protocols/error_strategy.py similarity index 88% rename from api/core/workflow/graph_engine/error_handling/error_strategy.py rename to api/core/workflow/graph_engine/protocols/error_strategy.py index 0d3c662888..bf8b316423 100644 --- a/api/core/workflow/graph_engine/error_handling/error_strategy.py +++ b/api/core/workflow/graph_engine/protocols/error_strategy.py @@ -2,7 +2,7 @@ Base error strategy protocol. """ -from typing import Optional, Protocol +from typing import Protocol from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent @@ -16,7 +16,7 @@ class ErrorStrategy(Protocol): node execution failures. """ - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]: + def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: """ Handle a node failure event. diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 40c7d19102..4c3cc167fa 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -9,7 +9,7 @@ import logging from collections import deque from collections.abc import Sequence from threading import RLock -from typing import Optional, TypeAlias +from typing import TypeAlias, final from uuid import uuid4 from core.workflow.enums import NodeExecutionType, NodeState @@ -28,6 +28,7 @@ NodeID: TypeAlias = str EdgeID: TypeAlias = str +@final class ResponseStreamCoordinator: """ Manages response streaming sessions without relying on global state. @@ -45,7 +46,7 @@ class ResponseStreamCoordinator: """ self.registry = registry self.graph = graph - self.active_session: Optional[ResponseSession] = None + self.active_session: ResponseSession | None = None self.waiting_sessions: deque[ResponseSession] = deque() self.lock = RLock() diff --git a/api/core/workflow/graph_engine/state_management/edge_state_manager.py b/api/core/workflow/graph_engine/state_management/edge_state_manager.py index 9e238a6fdd..747062284a 100644 --- a/api/core/workflow/graph_engine/state_management/edge_state_manager.py +++ b/api/core/workflow/graph_engine/state_management/edge_state_manager.py @@ -3,7 +3,8 @@ Manager for edge states during graph execution. """ import threading -from typing import TypedDict +from collections.abc import Sequence +from typing import TypedDict, final from core.workflow.enums import NodeState from core.workflow.graph import Edge, Graph @@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict): all_skipped: bool +@final class EdgeStateManager: """ Manages edge states and transitions during graph execution. @@ -87,7 +89,7 @@ class EdgeStateManager: with self._lock: return self.graph.edges[edge_id].state - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]: + def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: """ Categorize branch edges into selected and unselected. @@ -100,8 +102,8 @@ class EdgeStateManager: """ with self._lock: outgoing_edges = self.graph.get_outgoing_edges(node_id) - selected_edges = [] - unselected_edges = [] + selected_edges: list[Edge] = [] + unselected_edges: list[Edge] = [] for edge in outgoing_edges: if edge.source_handle == selected_handle: diff --git a/api/core/workflow/graph_engine/state_management/execution_tracker.py b/api/core/workflow/graph_engine/state_management/execution_tracker.py index 2008f30777..01fa80f2ce 100644 --- a/api/core/workflow/graph_engine/state_management/execution_tracker.py +++ b/api/core/workflow/graph_engine/state_management/execution_tracker.py @@ -3,8 +3,10 @@ Tracker for currently executing nodes. """ import threading +from typing import final +@final class ExecutionTracker: """ Tracks nodes that are currently being executed. diff --git a/api/core/workflow/graph_engine/state_management/node_state_manager.py b/api/core/workflow/graph_engine/state_management/node_state_manager.py index 61bb639cda..d5ed42ad1d 100644 --- a/api/core/workflow/graph_engine/state_management/node_state_manager.py +++ b/api/core/workflow/graph_engine/state_management/node_state_manager.py @@ -4,11 +4,13 @@ Manager for node states during graph execution. import queue import threading +from typing import final from core.workflow.enums import NodeState from core.workflow.graph import Graph +@final class NodeStateManager: """ Manages node states and the ready queue for execution. diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index bc4025978a..dacf6f0435 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -11,7 +11,7 @@ import threading import time from collections.abc import Callable from datetime import datetime -from typing import Optional +from typing import final from uuid import uuid4 from flask import Flask @@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts +@final class Worker(threading.Thread): """ Worker thread that executes nodes from the ready queue. @@ -38,10 +39,10 @@ class Worker(threading.Thread): event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, worker_id: int = 0, - flask_app: Optional[Flask] = None, - context_vars: Optional[contextvars.Context] = None, - on_idle_callback: Optional[Callable[[int], None]] = None, - on_active_callback: Optional[Callable[[int], None]] = None, + flask_app: Flask | None = None, + context_vars: contextvars.Context | None = None, + on_idle_callback: Callable[[int], None] | None = None, + on_active_callback: Callable[[int], None] | None = None, ) -> None: """ Initialize worker thread. diff --git a/api/core/workflow/graph_engine/worker_management/activity_tracker.py b/api/core/workflow/graph_engine/worker_management/activity_tracker.py index 5203fc6b6c..b2125a0158 100644 --- a/api/core/workflow/graph_engine/worker_management/activity_tracker.py +++ b/api/core/workflow/graph_engine/worker_management/activity_tracker.py @@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity. import threading import time +from typing import final +@final class ActivityTracker: """ Tracks worker activity for scaling decisions. diff --git a/api/core/workflow/graph_engine/worker_management/dynamic_scaler.py b/api/core/workflow/graph_engine/worker_management/dynamic_scaler.py index 7a1920a724..7450b02618 100644 --- a/api/core/workflow/graph_engine/worker_management/dynamic_scaler.py +++ b/api/core/workflow/graph_engine/worker_management/dynamic_scaler.py @@ -2,9 +2,12 @@ Dynamic scaler for worker pool sizing. """ +from typing import final + from core.workflow.graph import Graph +@final class DynamicScaler: """ Manages dynamic scaling decisions for the worker pool. diff --git a/api/core/workflow/graph_engine/worker_management/worker_factory.py b/api/core/workflow/graph_engine/worker_management/worker_factory.py index 76cfc45b10..673ca11f26 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_factory.py +++ b/api/core/workflow/graph_engine/worker_management/worker_factory.py @@ -5,7 +5,7 @@ Factory for creating worker instances. import contextvars import queue from collections.abc import Callable -from typing import Optional +from typing import final from flask import Flask @@ -14,6 +14,7 @@ from core.workflow.graph import Graph from ..worker import Worker +@final class WorkerFactory: """ Factory for creating worker instances with proper context. @@ -24,7 +25,7 @@ class WorkerFactory: def __init__( self, - flask_app: Optional[Flask], + flask_app: Flask | None, context_vars: contextvars.Context, ) -> None: """ @@ -43,8 +44,8 @@ class WorkerFactory: ready_queue: queue.Queue[str], event_queue: queue.Queue, graph: Graph, - on_idle_callback: Optional[Callable[[int], None]] = None, - on_active_callback: Optional[Callable[[int], None]] = None, + on_idle_callback: Callable[[int], None] | None = None, + on_active_callback: Callable[[int], None] | None = None, ) -> Worker: """ Create a new worker instance. diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 8faa9da156..55250809cd 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -4,6 +4,7 @@ Worker pool management. import queue import threading +from typing import final from core.workflow.graph import Graph @@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler from .worker_factory import WorkerFactory +@final class WorkerPool: """ Manages a pool of worker threads for executing nodes. diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 905cb49be2..2331c65de8 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any, Optional from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node @@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData class StartNode(Node): node_type = NodeType.START + execution_type = NodeExecutionType.ROOT _node_data: StartNodeData diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index d13393dd14..2960cde242 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -65,7 +65,7 @@ class Storage: from extensions.storage.volcengine_tos_storage import VolcengineTosStorage return VolcengineTosStorage - case StorageType.SUPBASE: + case StorageType.SUPABASE: from extensions.storage.supabase_storage import SupabaseStorage return SupabaseStorage diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index bc2d632159..baffa423b6 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -14,4 +14,4 @@ class StorageType(StrEnum): S3 = "s3" TENCENT_COS = "tencent-cos" VOLCENGINE_TOS = "volcengine-tos" - SUPBASE = "supabase" + SUPABASE = "supabase" diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 2ccab55dae..2104e66254 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -137,10 +137,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen return cast(Variable, result) -def infer_segment_type_from_value(value: Any, /) -> SegmentType: - return build_segment(value).value_type - - def build_segment(value: Any, /) -> Segment: # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` # below diff --git a/api/libs/helper.py b/api/libs/helper.py index d4f15ca937..96e8524660 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -301,8 +301,8 @@ class TokenManager: if expiry_minutes is None: raise ValueError(f"Expiry minutes for {token_type} token is not set") token_key = cls._get_token_key(token, token_type) - expiry_time = int(expiry_minutes * 60) - redis_client.setex(token_key, expiry_time, json.dumps(token_data)) + expiry_seconds = int(expiry_minutes * 60) + redis_client.setex(token_key, expiry_seconds, json.dumps(token_data)) if account_id: cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes) @@ -336,11 +336,11 @@ class TokenManager: @classmethod def _set_current_token_for_account( - cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] + cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] ): key = cls._get_account_token_key(account_id, token_type) - expiry_time = int(expiry_hours * 60 * 60) - redis_client.setex(key, expiry_time, token) + expiry_seconds = int(expiry_minutes * 60) + redis_client.setex(key, expiry_seconds, token) @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: diff --git a/api/models/model.py b/api/models/model.py index eea488647e..47bc98a148 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -1556,7 +1556,7 @@ class ApiToken(Base): def generate_api_key(prefix, n): while True: result = prefix + generate_string(n) - if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0: + if db.session.scalar(select(exists().where(ApiToken.token == result))): continue return result diff --git a/api/models/workflow.py b/api/models/workflow.py index e0c02d7cd4..79d476e69a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, orm +from sqlalchemy import DateTime, exists, orm, select from core.file.constants import maybe_file_object from core.file.models import File @@ -348,12 +348,13 @@ class Workflow(Base): """ from models.tools import WorkflowToolProvider - return ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) - .count() - > 0 + stmt = select( + exists().where( + WorkflowToolProvider.tenant_id == self.tenant_id, + WorkflowToolProvider.app_id == self.app_id, + ) ) + return db.session.execute(stmt).scalar_one() @property def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: @@ -952,7 +953,7 @@ def _naive_utc_datetime(): class WorkflowDraftVariable(Base): """`WorkflowDraftVariable` record variables and outputs generated during - debugging worfklow or chatflow. + debugging workflow or chatflow. IMPORTANT: This model maintains multiple invariant rules that must be preserved. Do not instantiate this class directly with the constructor. diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 46b2c61800..3405151c66 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,7 +9,7 @@ from collections import Counter from typing import Any, Literal, Optional from flask_login import current_user -from sqlalchemy import func, select +from sqlalchemy import exists, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -845,10 +845,8 @@ class DatasetService: @staticmethod def dataset_use_check(dataset_id) -> bool: - count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() - if count > 0: - return True - return False + stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id)) + return db.session.execute(stmt).scalar_one() @staticmethod def check_dataset_permission(dataset, user): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index c2d730fccf..5a6a24e57d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from pathlib import Path from typing import Any, Optional +from sqlalchemy import exists, select from sqlalchemy.orm import Session from configs import dify_config @@ -190,11 +191,14 @@ class BuiltinToolManageService: # update name if provided if name and name != db_provider.name: # check if the name is already used - if ( - session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider, name=name) - .count() - > 0 + if session.scalar( + select( + exists().where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.name == name, + ) + ) ): raise ValueError(f"the credential name '{name}' is already used") @@ -246,11 +250,14 @@ class BuiltinToolManageService: ) else: # check if the name is already used - if ( - session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, provider=provider, name=name) - .count() - > 0 + if session.scalar( + select( + exists().where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.name == name, + ) + ) ): raise ValueError(f"the credential name '{name}' is already used") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index baefab3454..2075a4cbec 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, Optional, cast -from sqlalchemy import select +from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType @@ -83,15 +83,14 @@ class WorkflowService: ) def is_workflow_exist(self, app_model: App) -> bool: - return ( - db.session.query(Workflow) - .where( + stmt = select( + exists().where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == Workflow.VERSION_DRAFT, ) - .count() - ) > 0 + ) + return db.session.execute(stmt).scalar_one() def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: """ diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c824059bf0..c0020b29ed 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import exists, select from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): start_at = time.perf_counter() # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() + annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) db.session.close() @@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): ) try: - if annotations_count > 0: + if annotations_exists: vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.delete() except Exception: diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py index 20f753786d..57ddacd13d 100644 --- a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -39,7 +39,7 @@ def test_page_result(text, cursor, maxlen, expected): # Tests: get_url # --------------------------- @pytest.fixture -def stub_support_types(monkeypatch): +def stub_support_types(monkeypatch: pytest.MonkeyPatch): """Stub supported content types list.""" import core.tools.utils.web_reader_tool as mod @@ -48,7 +48,7 @@ def stub_support_types(monkeypatch): return mod -def test_get_url_unsupported_content_type(monkeypatch, stub_support_types): +def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types): # HEAD 200 but content-type not supported and not text/html def fake_head(url, headers=None, follow_redirects=True, timeout=None): return FakeResponse( @@ -62,7 +62,7 @@ def test_get_url_unsupported_content_type(monkeypatch, stub_support_types): assert result == "Unsupported content-type [image/png] of URL." -def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types): +def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types): """ When content-type is in SUPPORT_URL_CONTENT_TYPES, should call ExtractProcessor.load_from_url and return its text. @@ -88,7 +88,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_ assert result == "PDF extracted text" -def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types): +def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types): """200 + text/html → GET, chardet detects encoding, readability returns article which is templated.""" def fake_head(url, headers=None, follow_redirects=True, timeout=None): @@ -121,7 +121,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_suppor assert "Hello world" in out -def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types): +def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types): """If readability returns no text, should return empty string.""" def fake_head(url, headers=None, follow_redirects=True, timeout=None): @@ -142,7 +142,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_su assert out == "" -def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types): +def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types): """HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed.""" def fake_head(url, headers=None, follow_redirects=True, timeout=None): @@ -175,7 +175,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types): assert "X" in out -def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types): +def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types): """HEAD returns non-200 and non-403 → should directly return code message.""" def fake_head(url, headers=None, follow_redirects=True, timeout=None): @@ -189,7 +189,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types): assert out == "URL returned status code 500." -def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types): +def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types): """ If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type, it should route to ExtractProcessor.load_from_url. @@ -213,7 +213,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch, stub_suppor assert out == "From ExtractProcessor via filename" -def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types): +def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types): """ If chardet returns an encoding but content.decode raises, should fallback to response.text. """ @@ -250,7 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_supp # --------------------------- -def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch): +def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch): # stub readabilipy.simple_json_from_html_string def fake_simple_json_from_html_string(html, use_readability=True): return { @@ -271,7 +271,7 @@ def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch): assert article.text[0]["text"] == "world" -def test_extract_using_readabilipy_defaults_when_missing(monkeypatch): +def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch): def fake_simple_json_from_html_string(html, use_readability=True): return {} # all missing diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index fa6fc3ba32..5348f729f9 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -8,7 +8,7 @@ from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch): +def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when `WorkflowAppGenerator.generate` returns a result with `error` key inside the `data` element. @@ -40,7 +40,7 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", lambda *args, **kwargs: {"data": {"error": "oops"}}, ) - monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) with pytest.raises(ToolInvokeError) as exc_info: # WorkflowTool always returns a generator, so we need to iterate to diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py new file mode 100644 index 0000000000..01b514ed7c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -0,0 +1,281 @@ +"""Unit tests for Graph class methods.""" + +from unittest.mock import Mock + +from core.workflow.enums import NodeExecutionType, NodeState, NodeType +from core.workflow.graph.edge import Edge +from core.workflow.graph.graph import Graph +from core.workflow.nodes.base.node import Node + + +def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: + """Create a mock node for testing.""" + node = Mock(spec=Node) + node.id = node_id + node.execution_type = execution_type + node.state = state + node.node_type = NodeType.START + return node + + +class TestMarkInactiveRootBranches: + """Test cases for _mark_inactive_root_branches method.""" + + def test_single_root_no_marking(self): + """Test that single root graph doesn't mark anything as skipped.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), + } + + in_edges = {"child1": ["edge1"]} + out_edges = {"root1": ["edge1"]} + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["child1"].state == NodeState.UNKNOWN + assert edges["edge1"].state == NodeState.UNKNOWN + + def test_multiple_roots_mark_inactive(self): + """Test marking inactive root branches with multiple root nodes.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "root2": create_mock_node("root2", NodeExecutionType.ROOT), + "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), + "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), + "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), + } + + in_edges = {"child1": ["edge1"], "child2": ["edge2"]} + out_edges = {"root1": ["edge1"], "root2": ["edge2"]} + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["root2"].state == NodeState.SKIPPED + assert nodes["child1"].state == NodeState.UNKNOWN + assert nodes["child2"].state == NodeState.SKIPPED + assert edges["edge1"].state == NodeState.UNKNOWN + assert edges["edge2"].state == NodeState.SKIPPED + + def test_shared_downstream_node(self): + """Test that shared downstream nodes are not skipped if at least one path is active.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "root2": create_mock_node("root2", NodeExecutionType.ROOT), + "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), + "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), + "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), + "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), + "edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"), + "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), + } + + in_edges = { + "child1": ["edge1"], + "child2": ["edge2"], + "shared": ["edge3", "edge4"], + } + out_edges = { + "root1": ["edge1"], + "root2": ["edge2"], + "child1": ["edge3"], + "child2": ["edge4"], + } + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["root2"].state == NodeState.SKIPPED + assert nodes["child1"].state == NodeState.UNKNOWN + assert nodes["child2"].state == NodeState.SKIPPED + assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active + assert edges["edge1"].state == NodeState.UNKNOWN + assert edges["edge2"].state == NodeState.SKIPPED + assert edges["edge3"].state == NodeState.UNKNOWN + assert edges["edge4"].state == NodeState.SKIPPED + + def test_deep_branch_marking(self): + """Test marking deep branches with multiple levels.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "root2": create_mock_node("root2", NodeExecutionType.ROOT), + "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), + "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), + "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), + "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), + "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), + "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), + "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), + "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), + "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), + } + + in_edges = { + "level1_a": ["edge1"], + "level1_b": ["edge2"], + "level2_a": ["edge3"], + "level2_b": ["edge4"], + "level3": ["edge5"], + } + out_edges = { + "root1": ["edge1"], + "root2": ["edge2"], + "level1_a": ["edge3"], + "level1_b": ["edge4"], + "level2_b": ["edge5"], + } + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["root2"].state == NodeState.SKIPPED + assert nodes["level1_a"].state == NodeState.UNKNOWN + assert nodes["level1_b"].state == NodeState.SKIPPED + assert nodes["level2_a"].state == NodeState.UNKNOWN + assert nodes["level2_b"].state == NodeState.SKIPPED + assert nodes["level3"].state == NodeState.SKIPPED + assert edges["edge1"].state == NodeState.UNKNOWN + assert edges["edge2"].state == NodeState.SKIPPED + assert edges["edge3"].state == NodeState.UNKNOWN + assert edges["edge4"].state == NodeState.SKIPPED + assert edges["edge5"].state == NodeState.SKIPPED + + def test_non_root_execution_type(self): + """Test that nodes with non-ROOT execution type are not treated as root nodes.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), + "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), + "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), + "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), + } + + in_edges = {"child1": ["edge1"], "child2": ["edge2"]} + out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped + assert nodes["child1"].state == NodeState.UNKNOWN + assert nodes["child2"].state == NodeState.UNKNOWN + assert edges["edge1"].state == NodeState.UNKNOWN + assert edges["edge2"].state == NodeState.UNKNOWN + + def test_empty_graph(self): + """Test handling of empty graph structures.""" + nodes = {} + edges = {} + in_edges = {} + out_edges = {} + + # Should not raise any errors + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") + + def test_three_roots_mark_two_inactive(self): + """Test with three root nodes where two should be marked inactive.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "root2": create_mock_node("root2", NodeExecutionType.ROOT), + "root3": create_mock_node("root3", NodeExecutionType.ROOT), + "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), + "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), + "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), + "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), + "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), + } + + in_edges = { + "child1": ["edge1"], + "child2": ["edge2"], + "child3": ["edge3"], + } + out_edges = { + "root1": ["edge1"], + "root2": ["edge2"], + "root3": ["edge3"], + } + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") + + assert nodes["root1"].state == NodeState.SKIPPED + assert nodes["root2"].state == NodeState.UNKNOWN # Active root + assert nodes["root3"].state == NodeState.SKIPPED + assert nodes["child1"].state == NodeState.SKIPPED + assert nodes["child2"].state == NodeState.UNKNOWN + assert nodes["child3"].state == NodeState.SKIPPED + assert edges["edge1"].state == NodeState.SKIPPED + assert edges["edge2"].state == NodeState.UNKNOWN + assert edges["edge3"].state == NodeState.SKIPPED + + def test_convergent_paths(self): + """Test convergent paths where multiple inactive branches lead to same node.""" + nodes = { + "root1": create_mock_node("root1", NodeExecutionType.ROOT), + "root2": create_mock_node("root2", NodeExecutionType.ROOT), + "root3": create_mock_node("root3", NodeExecutionType.ROOT), + "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), + "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), + "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), + } + + edges = { + "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), + "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), + "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), + "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), + "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), + } + + in_edges = { + "mid1": ["edge1"], + "mid2": ["edge2"], + "convergent": ["edge3", "edge4", "edge5"], + } + out_edges = { + "root1": ["edge1"], + "root2": ["edge2"], + "root3": ["edge3"], + "mid1": ["edge4"], + "mid2": ["edge5"], + } + + Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") + + assert nodes["root1"].state == NodeState.UNKNOWN + assert nodes["root2"].state == NodeState.SKIPPED + assert nodes["root3"].state == NodeState.SKIPPED + assert nodes["mid1"].state == NodeState.UNKNOWN + assert nodes["mid2"].state == NodeState.SKIPPED + assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 + assert edges["edge1"].state == NodeState.UNKNOWN + assert edges["edge2"].state == NodeState.SKIPPED + assert edges["edge3"].state == NodeState.SKIPPED + assert edges["edge4"].state == NodeState.UNKNOWN + assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index 61f6fb1af4..fc38393e75 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -21,7 +21,6 @@ from .test_mock_config import MockConfigBuilder from .test_table_runner import TableTestRunner, WorkflowTestCase -@pytest.mark.skip class TestComplexBranchWorkflow: """Test suite for complex branch workflow with parallel execution.""" @@ -30,6 +29,7 @@ class TestComplexBranchWorkflow: self.runner = TableTestRunner() self.fixture_path = "test_complex_branch" + @pytest.mark.skip(reason="output in this workflow can be random") def test_hello_branch_with_llm(self): """ Test when query contains 'hello' - should trigger true branch. diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 3da0601e70..c6e5f72888 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,7 +12,7 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from pathlib import Path @@ -34,7 +34,11 @@ from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.system_variable import SystemVariable from models.enums import UserFrom @@ -57,7 +61,7 @@ class WorkflowTestCase: timeout: float = 30.0 mock_config: Optional[MockConfig] = None use_auto_mock: bool = False - expected_event_sequence: Optional[list[type[GraphEngineEvent]]] = None + expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None tags: list[str] = field(default_factory=list) skip: bool = False skip_reason: str = "" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py index 2d26931f18..221e1291d1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py @@ -9,13 +9,6 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ from .test_table_runner import TableTestRunner, WorkflowTestCase -def mock_template_transform_run(self): - """Mock the TemplateTransformNode._run() method to return results based on node title.""" - title = self._node_data.title - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) - - -@pytest.mark.skip class TestVariableAggregator: """Test cases for the variable aggregator workflow.""" @@ -37,6 +30,12 @@ class TestVariableAggregator: description: str, ) -> None: """Test all four combinations of switch1 and switch2.""" + + def mock_template_transform_run(self): + """Mock the TemplateTransformNode._run() method to return results based on node title.""" + title = self._node_data.title + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) + with patch.object( TemplateTransformNode, "_run", diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py deleted file mode 100644 index d632c336c5..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ /dev/null @@ -1,353 +0,0 @@ -import httpx -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file import File, FileTransferMethod, FileType -from core.variables import ArrayFileVariable, FileVariable -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute -from core.workflow.nodes.end.entities import EndStreamParam -from core.workflow.nodes.http_request import ( - BodyData, - HttpRequestNode, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeData, -) -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom - - -@pytest.mark.skip( - reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " - "needs rewrite for new architecture" -) -def test_http_request_node_binary_file(monkeypatch): - data = HttpRequestNodeData( - title="test", - method="post", - url="http://example.org/post", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="", - params="", - body=HttpRequestNodeBody( - type="binary", - data=[ - BodyData( - key="file", - type="file", - value="", - file=["1111", "file"], - ) - ], - ), - ) - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - ) - variable_pool.add( - ["1111", "file"], - FileVariable( - name="file", - value=File( - tenant_id="1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1111", - storage_key="", - ), - ), - ) - - node_config = { - "id": "1", - "data": data.model_dump(), - } - - node = HttpRequestNode( - id="1", - config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - - # Initialize node data - node.init_node_data(node_config["data"]) - monkeypatch.setattr( - "core.workflow.nodes.http_request.executor.file_manager.download", - lambda *args, **kwargs: b"test", - ) - monkeypatch.setattr( - "core.helper.ssrf_proxy.post", - lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]), - ) - result = node._run() - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs is not None - assert result.outputs["body"] == "test" - - -@pytest.mark.skip( - reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " - "needs rewrite for new architecture" -) -def test_http_request_node_form_with_file(monkeypatch): - data = HttpRequestNodeData( - title="test", - method="post", - url="http://example.org/post", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="", - params="", - body=HttpRequestNodeBody( - type="form-data", - data=[ - BodyData( - key="file", - type="file", - file=["1111", "file"], - ), - BodyData( - key="name", - type="text", - value="test", - ), - ], - ), - ) - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - ) - variable_pool.add( - ["1111", "file"], - FileVariable( - name="file", - value=File( - tenant_id="1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1111", - storage_key="", - ), - ), - ) - - node_config = { - "id": "1", - "data": data.model_dump(), - } - - node = HttpRequestNode( - id="1", - config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - - # Initialize node data - node.init_node_data(node_config["data"]) - - monkeypatch.setattr( - "core.workflow.nodes.http_request.executor.file_manager.download", - lambda *args, **kwargs: b"test", - ) - - def attr_checker(*args, **kwargs): - assert kwargs["data"] == {"name": "test"} - assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))] - return httpx.Response(200, content=b"") - - monkeypatch.setattr( - "core.helper.ssrf_proxy.post", - attr_checker, - ) - result = node._run() - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs is not None - assert result.outputs["body"] == "" - - -@pytest.mark.skip( - reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - " - "needs rewrite for new architecture" -) -def test_http_request_node_form_with_multiple_files(monkeypatch): - data = HttpRequestNodeData( - title="test", - method="post", - url="http://example.org/upload", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="", - params="", - body=HttpRequestNodeBody( - type="form-data", - data=[ - BodyData( - key="files", - type="file", - file=["1111", "files"], - ), - BodyData( - key="name", - type="text", - value="test", - ), - ], - ), - ) - - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - ) - - files = [ - File( - tenant_id="1", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="file1", - filename="image1.jpg", - mime_type="image/jpeg", - storage_key="", - ), - File( - tenant_id="1", - type=FileType.DOCUMENT, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="file2", - filename="document.pdf", - mime_type="application/pdf", - storage_key="", - ), - ] - - variable_pool.add( - ["1111", "files"], - ArrayFileVariable( - name="files", - value=files, - ), - ) - - node_config = { - "id": "1", - "data": data.model_dump(), - } - - node = HttpRequestNode( - id="1", - config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - - # Initialize node data - node.init_node_data(node_config["data"]) - - monkeypatch.setattr( - "core.workflow.nodes.http_request.executor.file_manager.download", - lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data", - ) - - def attr_checker(*args, **kwargs): - assert kwargs["data"] == {"name": "test"} - - assert len(kwargs["files"]) == 2 - assert kwargs["files"][0][0] == "files" - assert kwargs["files"][1][0] == "files" - - file_tuples = [f[1] for f in kwargs["files"]] - file_contents = [f[1] for f in file_tuples] - file_types = [f[2] for f in file_tuples] - - assert b"test_image_data" in file_contents - assert b"test_pdf_data" in file_contents - assert "image/jpeg" in file_types - assert "application/pdf" in file_types - - return httpx.Response(200, content=b'{"status":"success"}') - - monkeypatch.setattr( - "core.helper.ssrf_proxy.post", - attr_checker, - ) - - result = node._run() - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert result.outputs is not None - assert result.outputs["body"] == '{"status":"success"}' - print(result.outputs["body"]) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py deleted file mode 100644 index 5a7b3aad52..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ /dev/null @@ -1,909 +0,0 @@ -import time -import uuid -from unittest.mock import patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables.segments import ArrayAnySegment, ArrayStringSegment -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.iteration.entities import ErrorHandleMode -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.node_factory import DifyNodeFactory -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom - - -@pytest.mark.skip( - reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" -) -def test_run(): - graph_config = { - "edges": [ - { - "id": "start-source-pe-target", - "source": "start", - "target": "pe", - }, - { - "id": "iteration-1-source-answer-3-target", - "source": "iteration-1", - "target": "answer-3", - }, - { - "id": "tt-source-if-else-target", - "source": "tt", - "target": "if-else", - }, - { - "id": "if-else-true-answer-2-target", - "source": "if-else", - "sourceHandle": "true", - "target": "answer-2", - }, - { - "id": "if-else-false-answer-4-target", - "source": "if-else", - "sourceHandle": "false", - "target": "answer-4", - }, - { - "id": "pe-source-iteration-1-target", - "source": "pe", - "target": "iteration-1", - }, - ], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "tt", - "title": "iteration", - "type": "iteration", - }, - "id": "iteration-1", - }, - { - "data": { - "answer": "{{#tt.output#}}", - "iteration_id": "iteration-1", - "title": "answer 2", - "type": "answer", - }, - "id": "answer-2", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }} 123", - "title": "template transform", - "type": "template-transform", - "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], - }, - "id": "tt", - }, - { - "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, - "id": "answer-3", - }, - { - "data": { - "conditions": [ - { - "comparison_operator": "is", - "id": "1721916275284", - "value": "hi", - "variable_selector": ["sys", "query"], - } - ], - "iteration_id": "iteration-1", - "logical_operator": "and", - "title": "if", - "type": "if-else", - }, - "id": "if-else", - }, - { - "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, - "id": "answer-4", - }, - { - "data": { - "instruction": "test1", - "model": { - "completion_params": {"temperature": 0.7}, - "mode": "chat", - "name": "gpt-4o", - "provider": "openai", - }, - "parameters": [ - {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} - ], - "query": ["sys", "query"], - "reasoning_mode": "prompt", - "title": "pe", - "type": "parameter-extractor", - }, - "id": "pe", - }, - ], - } - - init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - # construct variable pool - pool = VariablePool( - system_variables=SystemVariable( - user_id="1", - files=[], - query="dify", - conversation_id="abababa", - ), - user_inputs={}, - environment_variables=[], - ) - pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) - - graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - - node_config = { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "tt", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - } - - iteration_node = IterationNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Initialize node data - iteration_node.init_node_data(node_config["data"]) - - def tt_generator(self): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"iterator_selector": "dify"}, - outputs={"output": "dify 123"}, - ) - - with patch.object(TemplateTransformNode, "_run", new=tt_generator): - # execute node - result = iteration_node._run() - - count = 0 - for item in result: - # print(type(item), item) - count += 1 - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} - - assert count == 20 - - -@pytest.mark.skip( - reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" -) -def test_run_parallel(): - graph_config = { - "edges": [ - { - "id": "start-source-pe-target", - "source": "start", - "target": "pe", - }, - { - "id": "iteration-1-source-answer-3-target", - "source": "iteration-1", - "target": "answer-3", - }, - { - "id": "iteration-start-source-tt-target", - "source": "iteration-start", - "target": "tt", - }, - { - "id": "iteration-start-source-tt-2-target", - "source": "iteration-start", - "target": "tt-2", - }, - { - "id": "tt-source-if-else-target", - "source": "tt", - "target": "if-else", - }, - { - "id": "tt-2-source-if-else-target", - "source": "tt-2", - "target": "if-else", - }, - { - "id": "if-else-true-answer-2-target", - "source": "if-else", - "sourceHandle": "true", - "target": "answer-2", - }, - { - "id": "if-else-false-answer-4-target", - "source": "if-else", - "sourceHandle": "false", - "target": "answer-4", - }, - { - "id": "pe-source-iteration-1-target", - "source": "pe", - "target": "iteration-1", - }, - ], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "iteration", - "type": "iteration", - }, - "id": "iteration-1", - }, - { - "data": { - "answer": "{{#tt.output#}}", - "iteration_id": "iteration-1", - "title": "answer 2", - "type": "answer", - }, - "id": "answer-2", - }, - { - "data": { - "iteration_id": "iteration-1", - "title": "iteration-start", - "type": "iteration-start", - }, - "id": "iteration-start", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }} 123", - "title": "template transform", - "type": "template-transform", - "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], - }, - "id": "tt", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }} 321", - "title": "template transform", - "type": "template-transform", - "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], - }, - "id": "tt-2", - }, - { - "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, - "id": "answer-3", - }, - { - "data": { - "conditions": [ - { - "comparison_operator": "is", - "id": "1721916275284", - "value": "hi", - "variable_selector": ["sys", "query"], - } - ], - "iteration_id": "iteration-1", - "logical_operator": "and", - "title": "if", - "type": "if-else", - }, - "id": "if-else", - }, - { - "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, - "id": "answer-4", - }, - { - "data": { - "instruction": "test1", - "model": { - "completion_params": {"temperature": 0.7}, - "mode": "chat", - "name": "gpt-4o", - "provider": "openai", - }, - "parameters": [ - {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} - ], - "query": ["sys", "query"], - "reasoning_mode": "prompt", - "title": "pe", - "type": "parameter-extractor", - }, - "id": "pe", - }, - ], - } - - init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - # construct variable pool - pool = VariablePool( - system_variables=SystemVariable( - user_id="1", - files=[], - query="dify", - conversation_id="abababa", - ), - user_inputs={}, - environment_variables=[], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) - - node_config = { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - }, - "id": "iteration-1", - } - - iteration_node = IterationNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Initialize node data - iteration_node.init_node_data(node_config["data"]) - - def tt_generator(self): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"iterator_selector": "dify"}, - outputs={"output": "dify 123"}, - ) - - with patch.object(TemplateTransformNode, "_run", new=tt_generator): - # execute node - result = iteration_node._run() - - count = 0 - for item in result: - count += 1 - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} - - assert count == 32 - - -@pytest.mark.skip( - reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" -) -def test_iteration_run_in_parallel_mode(): - graph_config = { - "edges": [ - { - "id": "start-source-pe-target", - "source": "start", - "target": "pe", - }, - { - "id": "iteration-1-source-answer-3-target", - "source": "iteration-1", - "target": "answer-3", - }, - { - "id": "iteration-start-source-tt-target", - "source": "iteration-start", - "target": "tt", - }, - { - "id": "iteration-start-source-tt-2-target", - "source": "iteration-start", - "target": "tt-2", - }, - { - "id": "tt-source-if-else-target", - "source": "tt", - "target": "if-else", - }, - { - "id": "tt-2-source-if-else-target", - "source": "tt-2", - "target": "if-else", - }, - { - "id": "if-else-true-answer-2-target", - "source": "if-else", - "sourceHandle": "true", - "target": "answer-2", - }, - { - "id": "if-else-false-answer-4-target", - "source": "if-else", - "sourceHandle": "false", - "target": "answer-4", - }, - { - "id": "pe-source-iteration-1-target", - "source": "pe", - "target": "iteration-1", - }, - ], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "iteration", - "type": "iteration", - }, - "id": "iteration-1", - }, - { - "data": { - "answer": "{{#tt.output#}}", - "iteration_id": "iteration-1", - "title": "answer 2", - "type": "answer", - }, - "id": "answer-2", - }, - { - "data": { - "iteration_id": "iteration-1", - "title": "iteration-start", - "type": "iteration-start", - }, - "id": "iteration-start", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }} 123", - "title": "template transform", - "type": "template-transform", - "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], - }, - "id": "tt", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }} 321", - "title": "template transform", - "type": "template-transform", - "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], - }, - "id": "tt-2", - }, - { - "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, - "id": "answer-3", - }, - { - "data": { - "conditions": [ - { - "comparison_operator": "is", - "id": "1721916275284", - "value": "hi", - "variable_selector": ["sys", "query"], - } - ], - "iteration_id": "iteration-1", - "logical_operator": "and", - "title": "if", - "type": "if-else", - }, - "id": "if-else", - }, - { - "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, - "id": "answer-4", - }, - { - "data": { - "instruction": "test1", - "model": { - "completion_params": {"temperature": 0.7}, - "mode": "chat", - "name": "gpt-4o", - "provider": "openai", - }, - "parameters": [ - {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} - ], - "query": ["sys", "query"], - "reasoning_mode": "prompt", - "title": "pe", - "type": "parameter-extractor", - }, - "id": "pe", - }, - ], - } - - init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - # construct variable pool - pool = VariablePool( - system_variables=SystemVariable( - user_id="1", - files=[], - query="dify", - conversation_id="abababa", - ), - user_inputs={}, - environment_variables=[], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) - - parallel_node_config = { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - } - - parallel_iteration_node = IterationNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=parallel_node_config, - ) - - # Initialize node data - parallel_iteration_node.init_node_data(parallel_node_config["data"]) - sequential_node_config = { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "迭代", - "type": "iteration", - "is_parallel": True, - }, - "id": "iteration-1", - } - - sequential_iteration_node = IterationNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=sequential_node_config, - ) - - # Initialize node data - sequential_iteration_node.init_node_data(sequential_node_config["data"]) - - def tt_generator(self): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"iterator_selector": "dify"}, - outputs={"output": "dify 123"}, - ) - - with patch.object(TemplateTransformNode, "_run", new=tt_generator): - # execute node - parallel_result = parallel_iteration_node._run() - sequential_result = sequential_iteration_node._run() - assert parallel_iteration_node._node_data.parallel_nums == 10 - assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED - count = 0 - parallel_arr = [] - sequential_arr = [] - for item in parallel_result: - count += 1 - parallel_arr.append(item) - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} - assert count == 32 - - for item in sequential_result: - sequential_arr.append(item) - count += 1 - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])} - assert count == 64 - - -@pytest.mark.skip( - reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine" -) -def test_iteration_run_error_handle(): - graph_config = { - "edges": [ - { - "id": "start-source-pe-target", - "source": "start", - "target": "pe", - }, - { - "id": "iteration-1-source-answer-3-target", - "source": "iteration-1", - "target": "answer-3", - }, - { - "id": "tt-source-if-else-target", - "source": "iteration-start", - "target": "if-else", - }, - { - "id": "if-else-true-answer-2-target", - "source": "if-else", - "sourceHandle": "true", - "target": "tt", - }, - { - "id": "if-else-false-answer-4-target", - "source": "if-else", - "sourceHandle": "false", - "target": "tt2", - }, - { - "id": "pe-source-iteration-1-target", - "source": "pe", - "target": "iteration-1", - }, - ], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt2", "output"], - "output_type": "array[string]", - "start_node_id": "if-else", - "title": "iteration", - "type": "iteration", - }, - "id": "iteration-1", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1.split(arg2) }}", - "title": "template transform", - "type": "template-transform", - "variables": [ - {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, - {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, - ], - }, - "id": "tt", - }, - { - "data": { - "iteration_id": "iteration-1", - "template": "{{ arg1 }}", - "title": "template transform", - "type": "template-transform", - "variables": [ - {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, - ], - }, - "id": "tt2", - }, - { - "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, - "id": "answer-3", - }, - { - "data": { - "iteration_id": "iteration-1", - "title": "iteration-start", - "type": "iteration-start", - }, - "id": "iteration-start", - }, - { - "data": { - "conditions": [ - { - "comparison_operator": "is", - "id": "1721916275284", - "value": "1", - "variable_selector": ["iteration-1", "item"], - } - ], - "iteration_id": "iteration-1", - "logical_operator": "and", - "title": "if", - "type": "if-else", - }, - "id": "if-else", - }, - { - "data": { - "instruction": "test1", - "model": { - "completion_params": {"temperature": 0.7}, - "mode": "chat", - "name": "gpt-4o", - "provider": "openai", - }, - "parameters": [ - {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} - ], - "query": ["sys", "query"], - "reasoning_mode": "prompt", - "title": "pe", - "type": "parameter-extractor", - }, - "id": "pe", - }, - ], - } - - init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - # construct variable pool - pool = VariablePool( - system_variables=SystemVariable( - user_id="1", - files=[], - query="dify", - conversation_id="abababa", - ), - user_inputs={}, - environment_variables=[], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - pool.add(["pe", "list_output"], ["1", "1"]) - error_node_config = { - "data": { - "iterator_selector": ["pe", "list_output"], - "output_selector": ["tt", "output"], - "output_type": "array[string]", - "startNodeType": "template-transform", - "start_node_id": "iteration-start", - "title": "iteration", - "type": "iteration", - "is_parallel": True, - "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, - }, - "id": "iteration-1", - } - - iteration_node = IterationNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=error_node_config, - ) - - # Initialize node data - iteration_node.init_node_data(error_node_config["data"]) - # execute continue on error node - result = iteration_node._run() - result_arr = [] - count = 0 - for item in result: - result_arr.append(item) - count += 1 - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[None, None])} - - assert count == 14 - # execute remove abnormal output - iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT - result = iteration_node._run() - count = 0 - for item in result: - count += 1 - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[])} - assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py deleted file mode 100644 index 3c5e75826f..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ /dev/null @@ -1,624 +0,0 @@ -import time -from unittest.mock import patch - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( - GraphRunPartialSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunStreamChunkEvent, -) -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom - - -class ContinueOnErrorTestHelper: - @staticmethod - def get_code_node( - code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} - ): - """Helper method to create a code node configuration""" - node = { - "id": "node", - "data": { - "outputs": {"result": {"type": "number"}}, - "error_strategy": error_strategy, - "title": "code", - "variables": [], - "code_language": "python3", - "code": "\n".join([line[4:] for line in code.split("\n")]), - "type": "code", - **retry_config, - }, - } - if default_value: - node["data"]["default_value"] = default_value - return node - - @staticmethod - def get_http_node( - error_strategy: str = "fail-branch", - default_value: dict | None = None, - authorization_success: bool = False, - retry_config: dict = {}, - ): - """Helper method to create a http node configuration""" - authorization = ( - { - "type": "api-key", - "config": { - "type": "basic", - "api_key": "ak-xxx", - "header": "api-key", - }, - } - if authorization_success - else { - "type": "api-key", - # missing config field - } - ) - node = { - "id": "node", - "data": { - "title": "http", - "desc": "", - "method": "get", - "url": "http://example.com", - "authorization": authorization, - "headers": "X-Header:123", - "params": "A:b", - "body": None, - "type": "http-request", - "error_strategy": error_strategy, - **retry_config, - }, - } - if default_value: - node["data"]["default_value"] = default_value - return node - - @staticmethod - def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): - """Helper method to create a http node configuration""" - node = { - "id": "node", - "data": { - "type": "http-request", - "title": "HTTP Request", - "desc": "", - "variables": [], - "method": "get", - "url": "https://api.github.com/issues", - "authorization": {"type": "no-auth", "config": None}, - "headers": "", - "params": "", - "body": {"type": "none", "data": []}, - "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, - "error_strategy": error_strategy, - }, - } - if default_value: - node["data"]["default_value"] = default_value - return node - - @staticmethod - def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): - """Helper method to create a tool node configuration""" - node = { - "id": "node", - "data": { - "title": "a", - "desc": "a", - "provider_id": "maths", - "provider_type": "builtin", - "provider_name": "maths", - "tool_name": "eval_expression", - "tool_label": "eval_expression", - "tool_configurations": {}, - "tool_parameters": { - "expression": { - "type": "variable", - "value": ["1", "123", "args1"], - } - }, - "type": "tool", - "error_strategy": error_strategy, - }, - } - if default_value: - node.node_data.default_value = default_value - return node - - @staticmethod - def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): - """Helper method to create a llm node configuration""" - node = { - "id": "node", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_template": [ - {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, - "error_strategy": error_strategy, - }, - } - if default_value: - node["data"]["default_value"] = default_value - return node - - @staticmethod - def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): - """Helper method to create a graph engine instance for testing""" - # Create graph initialization parameters - init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="aaa", - files=[], - query="clear", - conversation_id="abababa", - ), - user_inputs=user_inputs or {"uid": "takato"}, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory(init_params, graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - - return GraphEngine( - tenant_id="111", - app_id="222", - workflow_id="333", - graph_config=graph_config, - user_id="444", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - graph=graph, - graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=1200, - command_channel=InMemoryChannel(), - ) - - -DEFAULT_VALUE_EDGE = [ - { - "id": "start-source-node-target", - "source": "start", - "target": "node", - "sourceHandle": "source", - }, - { - "id": "node-source-answer-target", - "source": "node", - "target": "answer", - "sourceHandle": "source", - }, -] - -FAIL_BRANCH_EDGES = [ - { - "id": "start-source-node-target", - "source": "start", - "target": "node", - "sourceHandle": "source", - }, - { - "id": "node-true-success-target", - "source": "node", - "target": "success", - "sourceHandle": "source", - }, - { - "id": "node-false-error-target", - "source": "node", - "target": "error", - "sourceHandle": "fail-branch", - }, -] - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_code_default_value_continue_on_error(): - error_code = """ - def main() -> dict: - return { - "result": 1 / 0, - } - """ - - graph_config = { - "edges": DEFAULT_VALUE_EDGE, - "nodes": [ - {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_code_node( - error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] - ), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_code_fail_branch_continue_on_error(): - error_code = """ - def main() -> dict: - return { - "result": 1 / 0, - } - """ - - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_code_node(error_code), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events - ) - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_http_node_default_value_continue_on_error(): - """Test HTTP node with default value error strategy""" - graph_config = { - "edges": DEFAULT_VALUE_EDGE, - "nodes": [ - {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_http_node( - "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] - ), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} - for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_http_node_fail_branch_continue_on_error(): - """Test HTTP node with fail-branch error strategy""" - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_http_node(), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -# def test_tool_node_default_value_continue_on_error(): -# """Test tool node with default value error strategy""" -# graph_config = { -# "edges": DEFAULT_VALUE_EDGE, -# "nodes": [ -# {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, -# {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, -# ContinueOnErrorTestHelper.get_tool_node( -# "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] -# ), -# ], -# } - -# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) -# events = list(graph_engine.run()) - -# assert any(isinstance(e, NodeRunExceptionEvent) for e in events) -# assert any( -# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501 -# ) -# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -# def test_tool_node_fail_branch_continue_on_error(): -# """Test HTTP node with fail-branch error strategy""" -# graph_config = { -# "edges": FAIL_BRANCH_EDGES, -# "nodes": [ -# {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, -# { -# "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, -# "id": "success", -# }, -# { -# "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, -# "id": "error", -# }, -# ContinueOnErrorTestHelper.get_tool_node(), -# ], -# } - -# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) -# events = list(graph_engine.run()) - -# assert any(isinstance(e, NodeRunExceptionEvent) for e in events) -# assert any( -# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501 -# ) -# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_llm_node_default_value_continue_on_error(): - """Test LLM node with default value error strategy""" - graph_config = { - "edges": DEFAULT_VALUE_EDGE, - "nodes": [ - {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_llm_node( - "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] - ), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_llm_node_fail_branch_continue_on_error(): - """Test LLM node with fail-branch error strategy""" - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_llm_node(), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_status_code_error_http_node_fail_branch_continue_on_error(): - """Test HTTP node with fail-branch error strategy""" - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_error_status_code_http_node(), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any( - isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_variable_pool_error_type_variable(): - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_error_status_code_http_node(), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - list(graph_engine.run()) - error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) - error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) - assert error_message != None - assert error_type.value == "HTTPResponseCodeError" - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_no_node_in_fail_branch_continue_on_error(): - """Test HTTP node with fail-branch error strategy""" - graph_config = { - "edges": FAIL_BRANCH_EDGES[:-1], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"}, - ContinueOnErrorTestHelper.get_http_node(), - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 - - -@pytest.mark.skip( - reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - " - "not fully implemented in MVP of queue-based engine" -) -def test_stream_output_with_fail_branch_continue_on_error(): - """Test stream output with fail-branch error strategy""" - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_llm_node(), - ], - } - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - - def llm_generator(self): - contents = ["hi", "bye", "good morning"] - - yield NodeRunStreamChunkEvent( - node_id=self.node_id, - node_type=self._node_type, - selector=[self.node_id, "text"], - chunk=contents[0], - is_final=False, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={}, - process_data={}, - outputs={}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 1, - WorkflowNodeExecutionMetadataKey.CURRENCY: "USD", - }, - ) - ) - - with patch.object(LLMNode, "_run", new=llm_generator): - events = list(graph_engine.run()) - assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1 - assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py deleted file mode 100644 index f4dc4477de..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ /dev/null @@ -1,116 +0,0 @@ -from collections.abc import Generator - -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType -from core.tools.errors import ToolInvokeError -from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute -from core.workflow.nodes.end.entities import EndStreamParam -from core.workflow.nodes.tool import ToolNode -from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.system_variable import SystemVariable -from models import UserFrom - - -def _create_tool_node(): - data = ToolNodeData( - title="Test Tool", - tool_parameters={}, - provider_id="test_tool", - provider_type=ToolProviderType.WORKFLOW, - provider_name="test tool", - tool_name="test tool", - tool_label="test tool", - tool_configurations={}, - plugin_unique_identifier=None, - desc="Exception handling test tool", - error_strategy=ErrorStrategy.FAIL_BRANCH, - version="1", - ) - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - ) - node_config = { - "id": "1", - "data": data.model_dump(), - } - node = ToolNode( - id="1", - config=node_config, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), - ) - # Initialize node data - node.init_node_data(node_config["data"]) - return node - - -class MockToolRuntime: - def get_merged_runtime_parameters(self): - pass - - -def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]: - yield from [] - raise ToolInvokeError("oops") - - -@pytest.mark.skip( - reason="Tool node test uses old Graph constructor incompatible with new queue-based engine - " - "needs rewrite for new architecture" -) -def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): - """Ensure that ToolNode can handle ToolInvokeError when transforming - messages generated by ToolEngine.generic_invoke. - """ - tool_node = _create_tool_node() - - # Need to patch ToolManager and ToolEngine so that we don't - # have to set up a database. - monkeypatch.setattr( - "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime() - ) - monkeypatch.setattr( - "core.tools.tool_engine.ToolEngine.generic_invoke", - lambda *args, **kwargs: mock_message_stream(), - ) - - streams = list(tool_node._run()) - assert len(streams) == 1 - stream = streams[0] - assert isinstance(stream, StreamCompletedEvent) - result = stream.node_run_result - assert isinstance(result, NodeRunResult) - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert "oops" in result.error - assert "Failed to invoke tool" in result.error - assert result.error_type == "ToolInvokeError" diff --git a/sdks/nodejs-client/index.d.ts b/sdks/nodejs-client/index.d.ts index a8b7497f4f..3ea4b9d153 100644 --- a/sdks/nodejs-client/index.d.ts +++ b/sdks/nodejs-client/index.d.ts @@ -14,6 +14,22 @@ interface HeaderParams { interface User { } +interface DifyFileBase { + type: "image" +} + +export interface DifyRemoteFile extends DifyFileBase { + transfer_method: "remote_url" + url: string +} + +export interface DifyLocalFile extends DifyFileBase { + transfer_method: "local_file" + upload_file_id: string +} + +export type DifyFile = DifyRemoteFile | DifyLocalFile; + export declare class DifyClient { constructor(apiKey: string, baseUrl?: string); @@ -44,7 +60,7 @@ export declare class CompletionClient extends DifyClient { inputs: any, user: User, stream?: boolean, - files?: File[] | null + files?: DifyFile[] | null ): Promise; } @@ -55,7 +71,7 @@ export declare class ChatClient extends DifyClient { user: User, stream?: boolean, conversation_id?: string | null, - files?: File[] | null + files?: DifyFile[] | null ): Promise; getSuggested(message_id: string, user: User): Promise; diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts index 0c6457fb0c..3f5303dfcc 100644 --- a/web/app/components/share/utils.ts +++ b/web/app/components/share/utils.ts @@ -32,6 +32,7 @@ export const checkOrSetAccessToken = async (appCode?: string | null) => { [userId || 'DEFAULT']: res.access_token, } localStorage.setItem('token', JSON.stringify(accessTokenJson)) + localStorage.removeItem(CONVERSATION_ID_INFO) } } diff --git a/web/context/web-app-context.tsx b/web/context/web-app-context.tsx index deb7aea53c..0fe1b56b0a 100644 --- a/web/context/web-app-context.tsx +++ b/web/context/web-app-context.tsx @@ -11,6 +11,7 @@ import type { FC, PropsWithChildren } from 'react' import { useEffect } from 'react' import { useState } from 'react' import { create } from 'zustand' +import { useGlobalPublicStore } from './global-public-context' type WebAppStore = { shareCode: string | null @@ -56,6 +57,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => { } const WebAppStoreProvider: FC = ({ children }) => { + const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending) const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode) const updateShareCode = useWebAppStore(state => state.updateShareCode) const pathname = usePathname() @@ -69,7 +71,7 @@ const WebAppStoreProvider: FC = ({ children }) => { }, [shareCode, updateShareCode]) const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode) - const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false) + const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true) useEffect(() => { if (accessModeResult?.accessMode) { @@ -86,7 +88,7 @@ const WebAppStoreProvider: FC = ({ children }) => { } }, [accessModeResult, updateWebAppAccessMode, shareCode]) - if (isFetching || isFetchingAccessToken) { + if (isGlobalPending || isFetching || isFetchingAccessToken) { return
diff --git a/web/service/base.ts b/web/service/base.ts index 98b76f072a..6c272d6a25 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -430,9 +430,7 @@ export const ssePost = async ( .then((res) => { if (!/^[23]\d{2}$/.test(String(res.status))) { if (res.status === 401) { - refreshAccessTokenOrRelogin(TIME_OUT).then(() => { - ssePost(url, fetchOptions, otherOptions) - }).catch(() => { + if (isPublicAPI) { res.json().then((data: any) => { if (isPublicAPI) { if (data.code === 'web_app_access_denied') @@ -449,7 +447,14 @@ export const ssePost = async ( } } }) - }) + } + else { + refreshAccessTokenOrRelogin(TIME_OUT).then(() => { + ssePost(url, fetchOptions, otherOptions) + }).catch((err) => { + console.error(err) + }) + } } else { res.json().then((data) => { diff --git a/web/service/use-share.ts b/web/service/use-share.ts index 63f18bf0e0..6845a2f3c7 100644 --- a/web/service/use-share.ts +++ b/web/service/use-share.ts @@ -1,20 +1,12 @@ -import { useGlobalPublicStore } from '@/context/global-public-context' -import { AccessMode } from '@/models/access-control' import { useQuery } from '@tanstack/react-query' import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share' const NAME_SPACE = 'webapp' export const useGetWebAppAccessModeByCode = (code: string | null) => { - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) return useQuery({ queryKey: [NAME_SPACE, 'appAccessMode', code], queryFn: () => { - if (systemFeatures.webapp_auth.enabled === false) { - return { - accessMode: AccessMode.PUBLIC, - } - } if (!code || code.length === 0) return Promise.reject(new Error('App code is required to get access mode'))