mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
commit
1db04aa729
|
|
@ -3,6 +3,7 @@ import logging
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
|
|
@ -94,21 +95,18 @@ class ChatMessageListApi(Resource):
|
|||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
)
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
|
|
|
|||
|
|
@ -8,20 +8,21 @@ from uuid import UUID
|
|||
|
||||
import numpy as np
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_json_value(v):
|
||||
if isinstance(v, datetime):
|
||||
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
|
||||
if not tz_name:
|
||||
tz_name = "UTC"
|
||||
tz_name = "UTC"
|
||||
if isinstance(current_user, Account) and current_user.timezone is not None:
|
||||
tz_name = current_user.timezone
|
||||
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
||||
elif isinstance(v, date):
|
||||
return v.isoformat()
|
||||
|
|
@ -46,7 +47,7 @@ def safe_json_value(v):
|
|||
return v
|
||||
|
||||
|
||||
def safe_json_dict(d):
|
||||
def safe_json_dict(d: dict):
|
||||
if not isinstance(d, dict):
|
||||
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
|
||||
return {k: safe_json_value(v) for k, v in d.items()}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ import logging
|
|||
from collections.abc import Generator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
|
|
@ -17,8 +15,8 @@ from core.tools.entities.tool_entities import (
|
|||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from libs.login import current_user
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -79,11 +77,11 @@ class WorkflowTool(Tool):
|
|||
generator = WorkflowAppGenerator()
|
||||
assert self.runtime is not None
|
||||
assert self.runtime.invoke_from is not None
|
||||
|
||||
assert current_user is not None
|
||||
result = generator.generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
user=cast("Account | EndUser", current_user),
|
||||
user=current_user,
|
||||
args={"inputs": tool_parameters, "files": files},
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class NodeExecutionType(StrEnum):
|
|||
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
|
||||
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
|
||||
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
|
||||
ROOT = "root" # Nodes that can serve as execution entry points
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Protocol, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .edge import Edge
|
||||
|
|
@ -36,10 +36,10 @@ class Graph:
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: Optional[dict[str, Node]] = None,
|
||||
edges: Optional[dict[str, Edge]] = None,
|
||||
in_edges: Optional[dict[str, list[str]]] = None,
|
||||
out_edges: Optional[dict[str, list[str]]] = None,
|
||||
nodes: dict[str, Node] | None = None,
|
||||
edges: dict[str, Edge] | None = None,
|
||||
in_edges: dict[str, list[str]] | None = None,
|
||||
out_edges: dict[str, list[str]] | None = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
|
|
@ -81,7 +81,7 @@ class Graph:
|
|||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
edge_configs: list[dict[str, Any]],
|
||||
root_node_id: Optional[str] = None,
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
|
|
@ -186,13 +186,79 @@ class Graph:
|
|||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: Optional[str] = None,
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
"""
|
||||
Initialize graph
|
||||
|
|
@ -227,6 +293,9 @@ class Graph:
|
|||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@ within a single process. Each instance handles commands for one workflow executi
|
|||
"""
|
||||
|
||||
from queue import Queue
|
||||
from typing import final
|
||||
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
||||
|
||||
@final
|
||||
class InMemoryChannel:
|
||||
"""
|
||||
In-memory command channel implementation using a thread-safe queue.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ Each instance uses a unique key for its command queue.
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ if TYPE_CHECKING:
|
|||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
class RedisChannel:
|
||||
"""
|
||||
Redis-based command channel implementation for distributed systems.
|
||||
|
|
@ -86,7 +87,7 @@ class RedisChannel:
|
|||
pipe.expire(self._key, self._command_ttl)
|
||||
pipe.execute()
|
||||
|
||||
def _deserialize_command(self, data: dict) -> Optional[GraphEngineCommand]:
|
||||
def _deserialize_command(self, data: dict) -> GraphEngineCommand | None:
|
||||
"""
|
||||
Deserialize a command from dictionary data.
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Command handler implementations.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
|
|
@ -11,6 +12,7 @@ from .command_processor import CommandHandler
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Main command processor for handling external commands.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Protocol
|
||||
from typing import Protocol, final
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import GraphEngineCommand
|
||||
|
|
@ -18,6 +18,7 @@ class CommandHandler(Protocol):
|
|||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
|
||||
|
||||
|
||||
@final
|
||||
class CommandProcessor:
|
||||
"""
|
||||
Processes external commands sent to the engine.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ GraphExecution aggregate root managing the overall graph execution state.
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .node_execution import NodeExecution
|
||||
|
||||
|
|
@ -21,7 +20,7 @@ class GraphExecution:
|
|||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
error: Optional[Exception] = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
|
||||
|
||||
def start(self) -> None:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ NodeExecution entity representing a node's execution state.
|
|||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
|
@ -20,8 +19,8 @@ class NodeExecution:
|
|||
node_id: str
|
||||
state: NodeState = NodeState.UNKNOWN
|
||||
retry_count: int = 0
|
||||
execution_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
execution_id: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
def mark_started(self, execution_id: str) -> None:
|
||||
"""Mark the node as started with an execution ID."""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ instance to control its execution flow.
|
|||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -23,11 +23,11 @@ class GraphEngineCommand(BaseModel):
|
|||
"""Base class for all GraphEngine commands."""
|
||||
|
||||
command_type: CommandType = Field(..., description="Type of command")
|
||||
payload: Optional[dict[str, Any]] = Field(default=None, description="Optional command payload")
|
||||
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
|
||||
|
||||
|
||||
class AbortCommand(GraphEngineCommand):
|
||||
"""Command to abort a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: Optional[str] = Field(default=None, description="Optional reason for abort")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ the Strategy pattern for clean separation of concerns.
|
|||
from .abort_strategy import AbortStrategy
|
||||
from .default_value_strategy import DefaultValueStrategy
|
||||
from .error_handler import ErrorHandler
|
||||
from .error_strategy import ErrorStrategy
|
||||
from .fail_branch_strategy import FailBranchStrategy
|
||||
from .retry_strategy import RetryStrategy
|
||||
|
||||
|
|
@ -16,7 +15,6 @@ __all__ = [
|
|||
"AbortStrategy",
|
||||
"DefaultValueStrategy",
|
||||
"ErrorHandler",
|
||||
"ErrorStrategy",
|
||||
"FailBranchStrategy",
|
||||
"RetryStrategy",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Abort error strategy implementation.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
|
|
@ -11,6 +11,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class AbortStrategy:
|
||||
"""
|
||||
Error strategy that aborts execution on failure.
|
||||
|
|
@ -19,7 +20,7 @@ class AbortStrategy:
|
|||
It stops the entire graph execution when a node fails.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by aborting execution.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Default value error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
|
|||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
@final
|
||||
class DefaultValueStrategy:
|
||||
"""
|
||||
Error strategy that uses default values on failure.
|
||||
|
|
@ -18,7 +19,7 @@ class DefaultValueStrategy:
|
|||
predefined default output values.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by using default values.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Main error handler that coordinates error strategies.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||
from ..domain import GraphExecution
|
||||
|
||||
|
||||
@final
|
||||
class ErrorHandler:
|
||||
"""
|
||||
Coordinates error handling strategies for node failures.
|
||||
|
|
@ -34,16 +35,16 @@ class ErrorHandler:
|
|||
graph: The workflow graph
|
||||
graph_execution: The graph execution state
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_execution = graph_execution
|
||||
self._graph = graph
|
||||
self._graph_execution = graph_execution
|
||||
|
||||
# Initialize strategies
|
||||
self.abort_strategy = AbortStrategy()
|
||||
self.retry_strategy = RetryStrategy()
|
||||
self.fail_branch_strategy = FailBranchStrategy()
|
||||
self.default_value_strategy = DefaultValueStrategy()
|
||||
self._abort_strategy = AbortStrategy()
|
||||
self._retry_strategy = RetryStrategy()
|
||||
self._fail_branch_strategy = FailBranchStrategy()
|
||||
self._default_value_strategy = DefaultValueStrategy()
|
||||
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> Optional[GraphNodeEventBase]:
|
||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
|
|
@ -56,14 +57,14 @@ class ErrorHandler:
|
|||
Returns:
|
||||
Optional new event to process, or None to abort
|
||||
"""
|
||||
node = self.graph.nodes[event.node_id]
|
||||
node = self._graph.nodes[event.node_id]
|
||||
# Get retry count from NodeExecution
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
retry_count = node_execution.retry_count
|
||||
|
||||
# First check if retry is configured and not exhausted
|
||||
if node.retry and retry_count < node.retry_config.max_retries:
|
||||
result = self.retry_strategy.handle_error(event, self.graph, retry_count)
|
||||
result = self._retry_strategy.handle_error(event, self._graph, retry_count)
|
||||
if result:
|
||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||
return result
|
||||
|
|
@ -71,12 +72,10 @@ class ErrorHandler:
|
|||
# Apply configured error strategy
|
||||
strategy = node.error_strategy
|
||||
|
||||
if strategy is None:
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self.fail_branch_strategy.handle_error(event, self.graph, retry_count)
|
||||
elif strategy == ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self.default_value_strategy.handle_error(event, self.graph, retry_count)
|
||||
else:
|
||||
# Unknown strategy, default to abort
|
||||
return self.abort_strategy.handle_error(event, self.graph, retry_count)
|
||||
match strategy:
|
||||
case None:
|
||||
return self._abort_strategy.handle_error(event, self._graph, retry_count)
|
||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||
return self._fail_branch_strategy.handle_error(event, self._graph, retry_count)
|
||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||
return self._default_value_strategy.handle_error(event, self._graph, retry_count)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Fail branch error strategy implementation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@ -10,6 +10,7 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent
|
|||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
@final
|
||||
class FailBranchStrategy:
|
||||
"""
|
||||
Error strategy that continues execution via a fail branch.
|
||||
|
|
@ -18,7 +19,7 @@ class FailBranchStrategy:
|
|||
through a designated fail-branch edge.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by taking the fail branch.
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@ Retry error strategy implementation.
|
|||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
|
||||
|
||||
|
||||
@final
|
||||
class RetryStrategy:
|
||||
"""
|
||||
Error strategy that retries failed nodes.
|
||||
|
|
@ -17,7 +18,7 @@ class RetryStrategy:
|
|||
maximum number of retries with configurable intervals.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle error by retrying the node.
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,92 @@ Event collector for buffering and managing events.
|
|||
"""
|
||||
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import Layer
|
||||
|
||||
|
||||
@final
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock implementation that allows multiple concurrent readers
|
||||
but only one writer at a time.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._read_ready = threading.Condition(threading.RLock())
|
||||
self._readers = 0
|
||||
|
||||
def acquire_read(self) -> None:
|
||||
"""Acquire a read lock."""
|
||||
self._read_ready.acquire()
|
||||
try:
|
||||
self._readers += 1
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
self._read_ready.acquire()
|
||||
try:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._read_ready.notify_all()
|
||||
finally:
|
||||
self._read_ready.release()
|
||||
|
||||
def acquire_write(self) -> None:
|
||||
"""Acquire a write lock."""
|
||||
self._read_ready.acquire()
|
||||
while self._readers > 0:
|
||||
self._read_ready.wait()
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release a write lock."""
|
||||
self._read_ready.release()
|
||||
|
||||
def read_lock(self) -> "ReadLockContext":
|
||||
"""Return a context manager for read locking."""
|
||||
return ReadLockContext(self)
|
||||
|
||||
def write_lock(self) -> "WriteLockContext":
|
||||
"""Return a context manager for write locking."""
|
||||
return WriteLockContext(self)
|
||||
|
||||
|
||||
@final
|
||||
class ReadLockContext:
|
||||
"""Context manager for read locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "ReadLockContext":
|
||||
self._lock.acquire_read()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_read()
|
||||
|
||||
|
||||
@final
|
||||
class WriteLockContext:
|
||||
"""Context manager for write locks."""
|
||||
|
||||
def __init__(self, lock: ReadWriteLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def __enter__(self) -> "WriteLockContext":
|
||||
self._lock.acquire_write()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
|
||||
self._lock.release_write()
|
||||
|
||||
|
||||
@final
|
||||
class EventCollector:
|
||||
"""
|
||||
Collects and buffers events for later retrieval.
|
||||
|
|
@ -20,7 +100,7 @@ class EventCollector:
|
|||
def __init__(self) -> None:
|
||||
"""Initialize the event collector."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = threading.Lock()
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
def set_layers(self, layers: list[Layer]) -> None:
|
||||
|
|
@ -39,7 +119,7 @@ class EventCollector:
|
|||
Args:
|
||||
event: The event to collect
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.write_lock():
|
||||
self._events.append(event)
|
||||
self._notify_layers(event)
|
||||
|
||||
|
|
@ -50,7 +130,7 @@ class EventCollector:
|
|||
Returns:
|
||||
List of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return list(self._events)
|
||||
|
||||
def get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
|
||||
|
|
@ -63,7 +143,7 @@ class EventCollector:
|
|||
Returns:
|
||||
List of new events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return list(self._events[start_index:])
|
||||
|
||||
def event_count(self) -> int:
|
||||
|
|
@ -73,12 +153,12 @@ class EventCollector:
|
|||
Returns:
|
||||
Number of collected events
|
||||
"""
|
||||
with self._lock:
|
||||
with self._lock.read_lock():
|
||||
return len(self._events)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all collected events."""
|
||||
with self._lock:
|
||||
with self._lock.write_lock():
|
||||
self._events.clear()
|
||||
|
||||
def _notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@ Event emitter for yielding events to external consumers.
|
|||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from .event_collector import EventCollector
|
||||
|
||||
|
||||
@final
|
||||
class EventEmitter:
|
||||
"""
|
||||
Emits collected events as a generator for external consumption.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Event handler implementations for different event types.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
|
|
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class EventHandlerRegistry:
|
||||
"""
|
||||
Registry of event handlers for different event types.
|
||||
|
|
@ -52,12 +53,12 @@ class EventHandlerRegistry:
|
|||
graph_runtime_state: GraphRuntimeState,
|
||||
graph_execution: GraphExecution,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
event_collector: Optional["EventCollector"] = None,
|
||||
branch_handler: Optional["BranchHandler"] = None,
|
||||
edge_processor: Optional["EdgeProcessor"] = None,
|
||||
node_state_manager: Optional["NodeStateManager"] = None,
|
||||
execution_tracker: Optional["ExecutionTracker"] = None,
|
||||
error_handler: Optional["ErrorHandler"] = None,
|
||||
event_collector: "EventCollector",
|
||||
branch_handler: "BranchHandler",
|
||||
edge_processor: "EdgeProcessor",
|
||||
node_state_manager: "NodeStateManager",
|
||||
execution_tracker: "ExecutionTracker",
|
||||
error_handler: "ErrorHandler",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the event handler registry.
|
||||
|
|
@ -67,23 +68,23 @@ class EventHandlerRegistry:
|
|||
graph_runtime_state: Runtime state with variable pool
|
||||
graph_execution: Graph execution aggregate
|
||||
response_coordinator: Response stream coordinator
|
||||
event_collector: Optional event collector for collecting events
|
||||
branch_handler: Optional branch handler for branch node processing
|
||||
edge_processor: Optional edge processor for edge traversal
|
||||
node_state_manager: Optional node state manager
|
||||
execution_tracker: Optional execution tracker
|
||||
error_handler: Optional error handler
|
||||
event_collector: Event collector for collecting events
|
||||
branch_handler: Branch handler for branch node processing
|
||||
edge_processor: Edge processor for edge traversal
|
||||
node_state_manager: Node state manager
|
||||
execution_tracker: Execution tracker
|
||||
error_handler: Error handler
|
||||
"""
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.graph_execution = graph_execution
|
||||
self.response_coordinator = response_coordinator
|
||||
self.event_collector = event_collector
|
||||
self.branch_handler = branch_handler
|
||||
self.edge_processor = edge_processor
|
||||
self.node_state_manager = node_state_manager
|
||||
self.execution_tracker = execution_tracker
|
||||
self.error_handler = error_handler
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_execution = graph_execution
|
||||
self._response_coordinator = response_coordinator
|
||||
self._event_collector = event_collector
|
||||
self._branch_handler = branch_handler
|
||||
self._edge_processor = edge_processor
|
||||
self._node_state_manager = node_state_manager
|
||||
self._execution_tracker = execution_tracker
|
||||
self._error_handler = error_handler
|
||||
|
||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
||||
"""
|
||||
|
|
@ -93,9 +94,8 @@ class EventHandlerRegistry:
|
|||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if isinstance(event, GraphNodeEventBase) and (event.in_loop_id or event.in_iteration_id):
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
|
||||
# Handle specific event types
|
||||
|
|
@ -125,12 +125,10 @@ class EventHandlerRegistry:
|
|||
),
|
||||
):
|
||||
# Iteration and loop events are collected directly
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
else:
|
||||
# Collect unhandled events
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
|
|
@ -141,15 +139,14 @@ class EventHandlerRegistry:
|
|||
event: The node started event
|
||||
"""
|
||||
# Track execution in domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_started(event.id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self.response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
|
|
@ -159,12 +156,11 @@ class EventHandlerRegistry:
|
|||
event: The stream chunk event
|
||||
"""
|
||||
# Process with response coordinator
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
streaming_events = list(self._response_coordinator.intercept_event(event))
|
||||
|
||||
# Collect all events
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
|
|
@ -177,55 +173,44 @@ class EventHandlerRegistry:
|
|||
event: The node succeeded event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event)
|
||||
|
||||
# Forward to response coordinator and emit streaming events
|
||||
streaming_events = list(self.response_coordinator.intercept_event(event))
|
||||
if self.event_collector:
|
||||
for stream_event in streaming_events:
|
||||
self.event_collector.collect(stream_event)
|
||||
streaming_events = self._response_coordinator.intercept_event(event)
|
||||
for stream_event in streaming_events:
|
||||
self._event_collector.collect(stream_event)
|
||||
|
||||
# Process edges and get ready nodes
|
||||
node = self.graph.nodes[event.node_id]
|
||||
node = self._graph.nodes[event.node_id]
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
if self.branch_handler:
|
||||
ready_nodes, edge_streaming_events = self.branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self._branch_handler.handle_branch_completion(
|
||||
event.node_id, event.node_run_result.edge_source_handle
|
||||
)
|
||||
else:
|
||||
if self.edge_processor:
|
||||
ready_nodes, edge_streaming_events = self.edge_processor.process_node_success(event.node_id)
|
||||
else:
|
||||
ready_nodes, edge_streaming_events = [], []
|
||||
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
|
||||
|
||||
# Collect streaming events from edge processing
|
||||
if self.event_collector:
|
||||
for edge_event in edge_streaming_events:
|
||||
self.event_collector.collect(edge_event)
|
||||
for edge_event in edge_streaming_events:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
if self.node_state_manager and self.execution_tracker:
|
||||
for node_id in ready_nodes:
|
||||
self.node_state_manager.enqueue_node(node_id)
|
||||
self.execution_tracker.add(node_id)
|
||||
for node_id in ready_nodes:
|
||||
self._node_state_manager.enqueue_node(node_id)
|
||||
self._execution_tracker.add(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
# Handle response node outputs
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._update_response_outputs(event)
|
||||
|
||||
# Collect the event
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
|
|
@ -235,29 +220,19 @@ class EventHandlerRegistry:
|
|||
event: The node failed event
|
||||
"""
|
||||
# Update domain model
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
|
||||
if self.error_handler:
|
||||
result = self.error_handler.handle_node_failure(event)
|
||||
result = self._error_handler.handle_node_failure(event)
|
||||
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Abort execution
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
if result:
|
||||
# Process the resulting event (retry, exception, etc.)
|
||||
self.handle_event(result)
|
||||
else:
|
||||
# Without error handler, just fail
|
||||
self.graph_execution.fail(RuntimeError(event.error))
|
||||
if self.event_collector:
|
||||
self.event_collector.collect(event)
|
||||
if self.execution_tracker:
|
||||
self.execution_tracker.remove(event.node_id)
|
||||
# Abort execution
|
||||
self._graph_execution.fail(RuntimeError(event.error))
|
||||
self._event_collector.collect(event)
|
||||
self._execution_tracker.remove(event.node_id)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
"""
|
||||
|
|
@ -267,7 +242,7 @@ class EventHandlerRegistry:
|
|||
event: The node exception event
|
||||
"""
|
||||
# Node continues via fail-branch, so it's technically "succeeded"
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
|
|
@ -277,7 +252,7 @@ class EventHandlerRegistry:
|
|||
Args:
|
||||
event: The node retry event
|
||||
"""
|
||||
node_execution = self.graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.increment_retry()
|
||||
|
||||
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
|
|
@ -288,16 +263,16 @@ class EventHandlerRegistry:
|
|||
event: The node succeeded event containing outputs
|
||||
"""
|
||||
for variable_name, variable_value in event.node_run_result.outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
|
||||
|
||||
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""Update response outputs for response nodes."""
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self.graph_runtime_state.outputs.get("answer", "")
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
if existing:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
self._graph_runtime_state.outputs["answer"] = value
|
||||
else:
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
self._graph_runtime_state.outputs[key] = value
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import contextvars
|
|||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
|
|
@ -20,6 +20,7 @@ from core.workflow.enums import NodeExecutionType
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
|
|
@ -44,6 +45,7 @@ from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, Wo
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
Queue-based graph execution engine.
|
||||
|
|
@ -62,7 +64,7 @@ class GraphEngine:
|
|||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
|
|
@ -103,7 +105,7 @@ class GraphEngine:
|
|||
|
||||
# Initialize queues
|
||||
self.ready_queue: queue.Queue[str] = queue.Queue()
|
||||
self.event_queue: queue.Queue = queue.Queue()
|
||||
self.event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# Initialize subsystems
|
||||
self._initialize_subsystems()
|
||||
|
|
@ -185,7 +187,7 @@ class GraphEngine:
|
|||
event_handler=self.event_handler_registry,
|
||||
event_collector=self.event_collector,
|
||||
command_processor=self.command_processor,
|
||||
worker_pool=self.worker_pool,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
self.dispatcher = Dispatcher(
|
||||
|
|
@ -209,7 +211,7 @@ class GraphEngine:
|
|||
def _setup_worker_management(self) -> None:
|
||||
"""Initialize worker management subsystem."""
|
||||
# Capture context for workers
|
||||
flask_app: Optional[Flask] = None
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
|
|
@ -218,8 +220,8 @@ class GraphEngine:
|
|||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker management components
|
||||
self.activity_tracker = ActivityTracker()
|
||||
self.dynamic_scaler = DynamicScaler(
|
||||
self._activity_tracker = ActivityTracker()
|
||||
self._dynamic_scaler = DynamicScaler(
|
||||
min_workers=(self._min_workers if self._min_workers is not None else dify_config.GRAPH_ENGINE_MIN_WORKERS),
|
||||
max_workers=(self._max_workers if self._max_workers is not None else dify_config.GRAPH_ENGINE_MAX_WORKERS),
|
||||
scale_up_threshold=(
|
||||
|
|
@ -233,15 +235,15 @@ class GraphEngine:
|
|||
else dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
|
||||
),
|
||||
)
|
||||
self.worker_factory = WorkerFactory(flask_app, context_vars)
|
||||
self._worker_factory = WorkerFactory(flask_app, context_vars)
|
||||
|
||||
self.worker_pool = WorkerPool(
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self.ready_queue,
|
||||
event_queue=self.event_queue,
|
||||
graph=self.graph,
|
||||
worker_factory=self.worker_factory,
|
||||
dynamic_scaler=self.dynamic_scaler,
|
||||
activity_tracker=self.activity_tracker,
|
||||
worker_factory=self._worker_factory,
|
||||
dynamic_scaler=self._dynamic_scaler,
|
||||
activity_tracker=self._activity_tracker,
|
||||
)
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
|
|
@ -319,10 +321,10 @@ class GraphEngine:
|
|||
def _start_execution(self) -> None:
|
||||
"""Start execution subsystems."""
|
||||
# Calculate initial worker count
|
||||
initial_workers = self.dynamic_scaler.calculate_initial_workers(self.graph)
|
||||
initial_workers = self._dynamic_scaler.calculate_initial_workers(self.graph)
|
||||
|
||||
# Start worker pool
|
||||
self.worker_pool.start(initial_workers)
|
||||
self._worker_pool.start(initial_workers)
|
||||
|
||||
# Register response nodes
|
||||
for node in self.graph.nodes.values():
|
||||
|
|
@ -340,7 +342,7 @@ class GraphEngine:
|
|||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self.dispatcher.stop()
|
||||
self.worker_pool.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
|
|
|
|||
|
|
@ -2,15 +2,18 @@
|
|||
Branch node handling for graph traversal.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
|
||||
|
||||
from ..state_management import EdgeStateManager
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class BranchHandler:
|
||||
"""
|
||||
Handles branch node logic during graph traversal.
|
||||
|
|
@ -40,7 +43,9 @@ class BranchHandler:
|
|||
self.skip_propagator = skip_propagator
|
||||
self.edge_state_manager = edge_state_manager
|
||||
|
||||
def handle_branch_completion(self, node_id: str, selected_handle: Optional[str]) -> tuple[list[str], list]:
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
|
|
@ -58,10 +63,10 @@ class BranchHandler:
|
|||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
_, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self.skip_propagator.skip_branch_paths(node_id, unselected_edges)
|
||||
self.skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.edge_processor.process_node_success(node_id, selected_handle)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,18 @@
|
|||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Edge, Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
|
@ -38,7 +43,9 @@ class EdgeProcessor:
|
|||
self.node_state_manager = node_state_manager
|
||||
self.response_coordinator = response_coordinator
|
||||
|
||||
def process_node_success(self, node_id: str, selected_handle: str | None = None) -> tuple[list[str], list]:
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
|
|
@ -56,7 +63,7 @@ class EdgeProcessor:
|
|||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[list[str], list]:
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
|
|
@ -66,8 +73,8 @@ class EdgeProcessor:
|
|||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
|
|
@ -77,7 +84,9 @@ class EdgeProcessor:
|
|||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(self, node_id: str, selected_handle: str | None) -> tuple[list[str], list]:
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
|
|
@ -94,8 +103,8 @@ class EdgeProcessor:
|
|||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes = []
|
||||
all_streaming_events = []
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self.edge_state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
|
@ -112,7 +121,7 @@ class EdgeProcessor:
|
|||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[list[str], list]:
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
|
|
@ -129,11 +138,11 @@ class EdgeProcessor:
|
|||
streaming_events = self.response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
if self.node_state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, list(streaming_events)
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@
|
|||
Node readiness checking for execution.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeReadinessChecker:
|
||||
"""
|
||||
Checks if nodes are ready for execution based on their dependencies.
|
||||
|
|
@ -71,7 +74,7 @@ class NodeReadinessChecker:
|
|||
Returns:
|
||||
List of node IDs that are now ready
|
||||
"""
|
||||
ready_nodes = []
|
||||
ready_nodes: list[str] = []
|
||||
outgoing_edges = self.graph.get_outgoing_edges(from_node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
|
|
|
|||
|
|
@ -2,11 +2,15 @@
|
|||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
||||
from ..state_management import EdgeStateManager, NodeStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
|
@ -57,9 +61,8 @@ class SkipPropagator:
|
|||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Check if node is ready and enqueue if so
|
||||
if self.node_state_manager.is_node_ready(downstream_node_id):
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
# Enqueue node
|
||||
self.node_state_manager.enqueue_node(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
|
|
@ -83,12 +86,11 @@ class SkipPropagator:
|
|||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, node_id: str, unselected_edges: list) -> None:
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ intercept and respond to GraphEngine events.
|
|||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
|
|
@ -28,8 +27,8 @@ class Layer(ABC):
|
|||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
self.command_channel: Optional[CommandChannel] = None
|
||||
self.graph_runtime_state: GraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
|
|
@ -73,7 +72,7 @@ class Layer(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ graph execution for debugging purposes.
|
|||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any, final
|
||||
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
|
|
@ -34,6 +34,7 @@ from core.workflow.graph_events import (
|
|||
from .base import Layer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(Layer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
|
@ -221,7 +222,7 @@ class DebugLoggingLayer(Layer):
|
|||
# Log unknown events at debug level
|
||||
self.logger.debug("Event: %s", event_class)
|
||||
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Log graph execution end with summary statistics."""
|
||||
self.logger.info("=" * 80)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
|||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
|
|
@ -29,6 +29,7 @@ class LimitType(Enum):
|
|||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(Layer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
|
@ -53,7 +54,7 @@ class ExecutionLimitsLayer(Layer):
|
|||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: Optional[float] = None
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -94,7 +95,7 @@ class ExecutionLimitsLayer(Layer):
|
|||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
def on_graph_end(self, error: Optional[Exception]) -> None:
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@ using the new Redis command channel, without requiring user permission checks.
|
|||
Supports stop, pause, and resume operations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngineManager:
|
||||
"""
|
||||
Manager for sending control commands to GraphEngine instances.
|
||||
|
|
@ -23,7 +24,7 @@ class GraphEngineManager:
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: Optional[str] = None) -> None:
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ import logging
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.workflow.graph_events.base import GraphNodeEventBase
|
||||
|
||||
from ..event_management import EventCollector, EventEmitter
|
||||
from .execution_coordinator import ExecutionCoordinator
|
||||
|
|
@ -17,6 +19,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class Dispatcher:
|
||||
"""
|
||||
Main dispatcher that processes events from the event queue.
|
||||
|
|
@ -27,12 +30,12 @@ class Dispatcher:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandlerRegistry",
|
||||
event_collector: EventCollector,
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
max_execution_time: int,
|
||||
event_emitter: Optional[EventEmitter] = None,
|
||||
event_emitter: EventEmitter | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dispatcher.
|
||||
|
|
@ -52,9 +55,9 @@ class Dispatcher:
|
|||
self.max_execution_time = max_execution_time
|
||||
self.event_emitter = event_emitter
|
||||
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: Optional[float] = None
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the dispatcher thread."""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
|
|
@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
|||
from ..event_management import EventHandlerRegistry
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
"""
|
||||
Coordinates overall execution flow between subsystems.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ thread-safe storage for node outputs.
|
|||
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Union, final
|
||||
|
||||
from core.variables import Segment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
|
@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
|||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class OutputRegistry:
|
||||
"""
|
||||
Thread-safe registry for storing and retrieving node outputs.
|
||||
|
|
@ -47,7 +48,7 @@ class OutputRegistry:
|
|||
with self._lock:
|
||||
self._scalars.add(selector, value)
|
||||
|
||||
def get_scalar(self, selector: Sequence[str]) -> Optional["Segment"]:
|
||||
def get_scalar(self, selector: Sequence[str]) -> "Segment | None":
|
||||
"""
|
||||
Get a scalar value for the given selector.
|
||||
|
||||
|
|
@ -81,7 +82,7 @@ class OutputRegistry:
|
|||
except ValueError:
|
||||
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
|
||||
|
||||
def pop_chunk(self, selector: Sequence[str]) -> Optional["NodeRunStreamChunkEvent"]:
|
||||
def pop_chunk(self, selector: Sequence[str]) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@ This module contains the private Stream class used internally by OutputRegistry
|
|||
to manage streaming data chunks.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
|
||||
@final
|
||||
class Stream:
|
||||
"""
|
||||
A stream that holds NodeRunStreamChunkEvent objects and tracks read position.
|
||||
|
|
@ -41,7 +42,7 @@ class Stream:
|
|||
raise ValueError("Cannot append to a closed stream")
|
||||
self.events.append(event)
|
||||
|
||||
def pop_next(self) -> Optional["NodeRunStreamChunkEvent"]:
|
||||
def pop_next(self) -> "NodeRunStreamChunkEvent | None":
|
||||
"""
|
||||
Pop the next unread NodeRunStreamChunkEvent from the stream.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Base error strategy protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
|
|
@ -16,7 +16,7 @@ class ErrorStrategy(Protocol):
|
|||
node execution failures.
|
||||
"""
|
||||
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> Optional[GraphNodeEventBase]:
|
||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
||||
"""
|
||||
Handle a node failure event.
|
||||
|
||||
|
|
@ -9,7 +9,7 @@ import logging
|
|||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from threading import RLock
|
||||
from typing import Optional, TypeAlias
|
||||
from typing import TypeAlias, final
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
|
|
@ -28,6 +28,7 @@ NodeID: TypeAlias = str
|
|||
EdgeID: TypeAlias = str
|
||||
|
||||
|
||||
@final
|
||||
class ResponseStreamCoordinator:
|
||||
"""
|
||||
Manages response streaming sessions without relying on global state.
|
||||
|
|
@ -45,7 +46,7 @@ class ResponseStreamCoordinator:
|
|||
"""
|
||||
self.registry = registry
|
||||
self.graph = graph
|
||||
self.active_session: Optional[ResponseSession] = None
|
||||
self.active_session: ResponseSession | None = None
|
||||
self.waiting_sessions: deque[ResponseSession] = deque()
|
||||
self.lock = RLock()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ Manager for edge states during graph execution.
|
|||
"""
|
||||
|
||||
import threading
|
||||
from typing import TypedDict
|
||||
from collections.abc import Sequence
|
||||
from typing import TypedDict, final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Edge, Graph
|
||||
|
|
@ -17,6 +18,7 @@ class EdgeStateAnalysis(TypedDict):
|
|||
all_skipped: bool
|
||||
|
||||
|
||||
@final
|
||||
class EdgeStateManager:
|
||||
"""
|
||||
Manages edge states and transitions during graph execution.
|
||||
|
|
@ -87,7 +89,7 @@ class EdgeStateManager:
|
|||
with self._lock:
|
||||
return self.graph.edges[edge_id].state
|
||||
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[list[Edge], list[Edge]]:
|
||||
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
|
||||
"""
|
||||
Categorize branch edges into selected and unselected.
|
||||
|
||||
|
|
@ -100,8 +102,8 @@ class EdgeStateManager:
|
|||
"""
|
||||
with self._lock:
|
||||
outgoing_edges = self.graph.get_outgoing_edges(node_id)
|
||||
selected_edges = []
|
||||
unselected_edges = []
|
||||
selected_edges: list[Edge] = []
|
||||
unselected_edges: list[Edge] = []
|
||||
|
||||
for edge in outgoing_edges:
|
||||
if edge.source_handle == selected_handle:
|
||||
|
|
|
|||
|
|
@ -3,8 +3,10 @@ Tracker for currently executing nodes.
|
|||
"""
|
||||
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionTracker:
|
||||
"""
|
||||
Tracks nodes that are currently being executed.
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@ Manager for node states during graph execution.
|
|||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class NodeStateManager:
|
||||
"""
|
||||
Manages node states and the ready queue for execution.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import threading
|
|||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
|
|
@ -23,6 +23,7 @@ from core.workflow.nodes.base.node import Node
|
|||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
"""
|
||||
Worker thread that executes nodes from the ready queue.
|
||||
|
|
@ -38,10 +39,10 @@ class Worker(threading.Thread):
|
|||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
worker_id: int = 0,
|
||||
flask_app: Optional[Flask] = None,
|
||||
context_vars: Optional[contextvars.Context] = None,
|
||||
on_idle_callback: Optional[Callable[[int], None]] = None,
|
||||
on_active_callback: Optional[Callable[[int], None]] = None,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ Activity tracker for monitoring worker activity.
|
|||
|
||||
import threading
|
||||
import time
|
||||
from typing import final
|
||||
|
||||
|
||||
@final
|
||||
class ActivityTracker:
|
||||
"""
|
||||
Tracks worker activity for scaling decisions.
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@
|
|||
Dynamic scaler for worker pool sizing.
|
||||
"""
|
||||
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
||||
@final
|
||||
class DynamicScaler:
|
||||
"""
|
||||
Manages dynamic scaling decisions for the worker pool.
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Factory for creating worker instances.
|
|||
import contextvars
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
from typing import final
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
|
@ -14,6 +14,7 @@ from core.workflow.graph import Graph
|
|||
from ..worker import Worker
|
||||
|
||||
|
||||
@final
|
||||
class WorkerFactory:
|
||||
"""
|
||||
Factory for creating worker instances with proper context.
|
||||
|
|
@ -24,7 +25,7 @@ class WorkerFactory:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
flask_app: Optional[Flask],
|
||||
flask_app: Flask | None,
|
||||
context_vars: contextvars.Context,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -43,8 +44,8 @@ class WorkerFactory:
|
|||
ready_queue: queue.Queue[str],
|
||||
event_queue: queue.Queue,
|
||||
graph: Graph,
|
||||
on_idle_callback: Optional[Callable[[int], None]] = None,
|
||||
on_active_callback: Optional[Callable[[int], None]] = None,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> Worker:
|
||||
"""
|
||||
Create a new worker instance.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Worker pool management.
|
|||
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph import Graph
|
||||
|
||||
|
|
@ -13,6 +14,7 @@ from .dynamic_scaler import DynamicScaler
|
|||
from .worker_factory import WorkerFactory
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
"""
|
||||
Manages a pool of worker threads for executing nodes.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
|
@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData
|
|||
|
||||
class StartNode(Node):
|
||||
node_type = NodeType.START
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class Storage:
|
|||
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
|
||||
|
||||
return VolcengineTosStorage
|
||||
case StorageType.SUPBASE:
|
||||
case StorageType.SUPABASE:
|
||||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
|
|
|
|||
|
|
@ -14,4 +14,4 @@ class StorageType(StrEnum):
|
|||
S3 = "s3"
|
||||
TENCENT_COS = "tencent-cos"
|
||||
VOLCENGINE_TOS = "volcengine-tos"
|
||||
SUPBASE = "supabase"
|
||||
SUPABASE = "supabase"
|
||||
|
|
|
|||
|
|
@ -137,10 +137,6 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
|||
return cast(Variable, result)
|
||||
|
||||
|
||||
def infer_segment_type_from_value(value: Any, /) -> SegmentType:
|
||||
return build_segment(value).value_type
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
|
||||
# below
|
||||
|
|
|
|||
|
|
@ -301,8 +301,8 @@ class TokenManager:
|
|||
if expiry_minutes is None:
|
||||
raise ValueError(f"Expiry minutes for {token_type} token is not set")
|
||||
token_key = cls._get_token_key(token, token_type)
|
||||
expiry_time = int(expiry_minutes * 60)
|
||||
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
|
||||
expiry_seconds = int(expiry_minutes * 60)
|
||||
redis_client.setex(token_key, expiry_seconds, json.dumps(token_data))
|
||||
|
||||
if account_id:
|
||||
cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)
|
||||
|
|
@ -336,11 +336,11 @@ class TokenManager:
|
|||
|
||||
@classmethod
|
||||
def _set_current_token_for_account(
|
||||
cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
|
||||
cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
|
||||
):
|
||||
key = cls._get_account_token_key(account_id, token_type)
|
||||
expiry_time = int(expiry_hours * 60 * 60)
|
||||
redis_client.setex(key, expiry_time, token)
|
||||
expiry_seconds = int(expiry_minutes * 60)
|
||||
redis_client.setex(key, expiry_seconds, token)
|
||||
|
||||
@classmethod
|
||||
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
|||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -1556,7 +1556,7 @@ class ApiToken(Base):
|
|||
def generate_api_key(prefix, n):
|
||||
while True:
|
||||
result = prefix + generate_string(n)
|
||||
if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
|
||||
if db.session.scalar(select(exists().where(ApiToken.token == result))):
|
||||
continue
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
|||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, orm
|
||||
from sqlalchemy import DateTime, exists, orm, select
|
||||
|
||||
from core.file.constants import maybe_file_object
|
||||
from core.file.models import File
|
||||
|
|
@ -348,12 +348,13 @@ class Workflow(Base):
|
|||
"""
|
||||
from models.tools import WorkflowToolProvider
|
||||
|
||||
return (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
|
||||
.count()
|
||||
> 0
|
||||
stmt = select(
|
||||
exists().where(
|
||||
WorkflowToolProvider.tenant_id == self.tenant_id,
|
||||
WorkflowToolProvider.app_id == self.app_id,
|
||||
)
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
@property
|
||||
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
|
|
@ -952,7 +953,7 @@ def _naive_utc_datetime():
|
|||
|
||||
class WorkflowDraftVariable(Base):
|
||||
"""`WorkflowDraftVariable` record variables and outputs generated during
|
||||
debugging worfklow or chatflow.
|
||||
debugging workflow or chatflow.
|
||||
|
||||
IMPORTANT: This model maintains multiple invariant rules that must be preserved.
|
||||
Do not instantiate this class directly with the constructor.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from collections import Counter
|
|||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -845,10 +845,8 @@ class DatasetService:
|
|||
|
||||
@staticmethod
|
||||
def dataset_use_check(dataset_id) -> bool:
|
||||
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id))
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_permission(dataset, user):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
|||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -190,11 +191,14 @@ class BuiltinToolManageService:
|
|||
# update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
|
|
@ -246,11 +250,14 @@ class BuiltinToolManageService:
|
|||
)
|
||||
else:
|
||||
# check if the name is already used
|
||||
if (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
||||
.count()
|
||||
> 0
|
||||
if session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.name == name,
|
||||
)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"the credential name '{name}' is already used")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import uuid
|
|||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
|
|
@ -83,15 +83,14 @@ class WorkflowService:
|
|||
)
|
||||
|
||||
def is_workflow_exist(self, app_model: App) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
stmt = select(
|
||||
exists().where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.count()
|
||||
) > 0
|
||||
)
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import exists, select
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -22,7 +23,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
|
||||
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
db.session.close()
|
||||
|
|
@ -47,7 +48,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
)
|
||||
|
||||
try:
|
||||
if annotations_count > 0:
|
||||
if annotations_exists:
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.delete()
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def test_page_result(text, cursor, maxlen, expected):
|
|||
# Tests: get_url
|
||||
# ---------------------------
|
||||
@pytest.fixture
|
||||
def stub_support_types(monkeypatch):
|
||||
def stub_support_types(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Stub supported content types list."""
|
||||
import core.tools.utils.web_reader_tool as mod
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ def stub_support_types(monkeypatch):
|
|||
return mod
|
||||
|
||||
|
||||
def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
|
||||
def test_get_url_unsupported_content_type(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
# HEAD 200 but content-type not supported and not text/html
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
return FakeResponse(
|
||||
|
|
@ -62,7 +62,7 @@ def test_get_url_unsupported_content_type(monkeypatch, stub_support_types):
|
|||
assert result == "Unsupported content-type [image/png] of URL."
|
||||
|
||||
|
||||
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types):
|
||||
def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""
|
||||
When content-type is in SUPPORT_URL_CONTENT_TYPES,
|
||||
should call ExtractProcessor.load_from_url and return its text.
|
||||
|
|
@ -88,7 +88,7 @@ def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_
|
|||
assert result == "PDF extracted text"
|
||||
|
||||
|
||||
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types):
|
||||
def test_get_url_html_flow_with_chardet_and_readability(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""200 + text/html → GET, chardet detects encoding, readability returns article which is templated."""
|
||||
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
|
|
@ -121,7 +121,7 @@ def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_suppor
|
|||
assert "Hello world" in out
|
||||
|
||||
|
||||
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types):
|
||||
def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""If readability returns no text, should return empty string."""
|
||||
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
|
|
@ -142,7 +142,7 @@ def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_su
|
|||
assert out == ""
|
||||
|
||||
|
||||
def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
|
||||
def test_get_url_403_cloudscraper_fallback(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed."""
|
||||
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
|
|
@ -175,7 +175,7 @@ def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types):
|
|||
assert "X" in out
|
||||
|
||||
|
||||
def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
|
||||
def test_get_url_head_non_200_returns_status(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""HEAD returns non-200 and non-403 → should directly return code message."""
|
||||
|
||||
def fake_head(url, headers=None, follow_redirects=True, timeout=None):
|
||||
|
|
@ -189,7 +189,7 @@ def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types):
|
|||
assert out == "URL returned status code 500."
|
||||
|
||||
|
||||
def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types):
|
||||
def test_get_url_content_disposition_filename_detection(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""
|
||||
If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type,
|
||||
it should route to ExtractProcessor.load_from_url.
|
||||
|
|
@ -213,7 +213,7 @@ def test_get_url_content_disposition_filename_detection(monkeypatch, stub_suppor
|
|||
assert out == "From ExtractProcessor via filename"
|
||||
|
||||
|
||||
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types):
|
||||
def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch: pytest.MonkeyPatch, stub_support_types):
|
||||
"""
|
||||
If chardet returns an encoding but content.decode raises, should fallback to response.text.
|
||||
"""
|
||||
|
|
@ -250,7 +250,7 @@ def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_supp
|
|||
# ---------------------------
|
||||
|
||||
|
||||
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
|
||||
def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch: pytest.MonkeyPatch):
|
||||
# stub readabilipy.simple_json_from_html_string
|
||||
def fake_simple_json_from_html_string(html, use_readability=True):
|
||||
return {
|
||||
|
|
@ -271,7 +271,7 @@ def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch):
|
|||
assert article.text[0]["text"] == "world"
|
||||
|
||||
|
||||
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch):
|
||||
def test_extract_using_readabilipy_defaults_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
def fake_simple_json_from_html_string(html, use_readability=True):
|
||||
return {} # all missing
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from core.tools.errors import ToolInvokeError
|
|||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||
the `data` element.
|
||||
|
|
@ -40,7 +40,7 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
|||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||
)
|
||||
monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
with pytest.raises(ToolInvokeError) as exc_info:
|
||||
# WorkflowTool always returns a generator, so we need to iterate to
|
||||
|
|
|
|||
|
|
@ -0,0 +1,281 @@
|
|||
"""Unit tests for Graph class methods."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph.edge import Edge
|
||||
from core.workflow.graph.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
|
||||
"""Create a mock node for testing."""
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.execution_type = execution_type
|
||||
node.state = state
|
||||
node.node_type = NodeType.START
|
||||
return node
|
||||
|
||||
|
||||
class TestMarkInactiveRootBranches:
|
||||
"""Test cases for _mark_inactive_root_branches method."""
|
||||
|
||||
def test_single_root_no_marking(self):
|
||||
"""Test that single root graph doesn't mark anything as skipped."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"]}
|
||||
out_edges = {"root1": ["edge1"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_multiple_roots_mark_inactive(self):
|
||||
"""Test marking inactive root branches with multiple root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "root2": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
|
||||
def test_shared_downstream_node(self):
|
||||
"""Test that shared downstream nodes are not skipped if at least one path is active."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"shared": ["edge3", "edge4"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"child1": ["edge3"],
|
||||
"child2": ["edge4"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
|
||||
def test_deep_branch_marking(self):
|
||||
"""Test marking deep branches with multiple levels."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
|
||||
"level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
|
||||
"level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
|
||||
"level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
|
||||
"level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"level1_a": ["edge1"],
|
||||
"level1_b": ["edge2"],
|
||||
"level2_a": ["edge3"],
|
||||
"level2_b": ["edge4"],
|
||||
"level3": ["edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"level1_a": ["edge3"],
|
||||
"level1_b": ["edge4"],
|
||||
"level2_b": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["level1_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level1_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level2_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level2_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
|
||||
def test_non_root_execution_type(self):
|
||||
"""Test that nodes with non-ROOT execution type are not treated as root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_empty_graph(self):
|
||||
"""Test handling of empty graph structures."""
|
||||
nodes = {}
|
||||
edges = {}
|
||||
in_edges = {}
|
||||
out_edges = {}
|
||||
|
||||
# Should not raise any errors
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")
|
||||
|
||||
def test_three_roots_mark_two_inactive(self):
|
||||
"""Test with three root nodes where two should be marked inactive."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"child3": ["edge3"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")
|
||||
|
||||
assert nodes["root1"].state == NodeState.SKIPPED
|
||||
assert nodes["root2"].state == NodeState.UNKNOWN # Active root
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.SKIPPED
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert nodes["child3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.SKIPPED
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
|
||||
def test_convergent_paths(self):
|
||||
"""Test convergent paths where multiple inactive branches lead to same node."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
|
||||
"mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
|
||||
"convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"mid1": ["edge1"],
|
||||
"mid2": ["edge2"],
|
||||
"convergent": ["edge3", "edge4", "edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
"mid1": ["edge4"],
|
||||
"mid2": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["mid1"].state == NodeState.UNKNOWN
|
||||
assert nodes["mid2"].state == NodeState.SKIPPED
|
||||
assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
assert edges["edge4"].state == NodeState.UNKNOWN
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
|
|
@ -21,7 +21,6 @@ from .test_mock_config import MockConfigBuilder
|
|||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
class TestComplexBranchWorkflow:
|
||||
"""Test suite for complex branch workflow with parallel execution."""
|
||||
|
||||
|
|
@ -30,6 +29,7 @@ class TestComplexBranchWorkflow:
|
|||
self.runner = TableTestRunner()
|
||||
self.fixture_path = "test_complex_branch"
|
||||
|
||||
@pytest.mark.skip(reason="output in this workflow can be random")
|
||||
def test_hello_branch_with_llm(self):
|
||||
"""
|
||||
Test when query contains 'hello' - should trigger true branch.
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ This module provides a robust table-driven testing framework with support for:
|
|||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
|
@ -34,7 +34,11 @@ from core.workflow.entities.graph_init_params import GraphInitParams
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
|
@ -57,7 +61,7 @@ class WorkflowTestCase:
|
|||
timeout: float = 30.0
|
||||
mock_config: Optional[MockConfig] = None
|
||||
use_auto_mock: bool = False
|
||||
expected_event_sequence: Optional[list[type[GraphEngineEvent]]] = None
|
||||
expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
|
||||
tags: list[str] = field(default_factory=list)
|
||||
skip: bool = False
|
||||
skip_reason: str = ""
|
||||
|
|
|
|||
|
|
@ -9,13 +9,6 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
|
|||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def mock_template_transform_run(self):
|
||||
"""Mock the TemplateTransformNode._run() method to return results based on node title."""
|
||||
title = self._node_data.title
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
class TestVariableAggregator:
|
||||
"""Test cases for the variable aggregator workflow."""
|
||||
|
||||
|
|
@ -37,6 +30,12 @@ class TestVariableAggregator:
|
|||
description: str,
|
||||
) -> None:
|
||||
"""Test all four combinations of switch1 and switch2."""
|
||||
|
||||
def mock_template_transform_run(self):
|
||||
"""Mock the TemplateTransformNode._run() method to return results based on node title."""
|
||||
title = self._node_data.title
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title})
|
||||
|
||||
with patch.object(
|
||||
TemplateTransformNode,
|
||||
"_run",
|
||||
|
|
|
|||
|
|
@ -1,353 +0,0 @@
|
|||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileVariable, FileVariable
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNode,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
|
||||
"needs rewrite for new architecture"
|
||||
)
|
||||
def test_http_request_node_binary_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="binary",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
value="",
|
||||
file=["1111", "file"],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
lambda *args, **kwargs: httpx.Response(200, content=kwargs["content"]),
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
|
||||
"needs rewrite for new architecture"
|
||||
)
|
||||
def test_http_request_node_form_with_file(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/post",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="file",
|
||||
type="file",
|
||||
file=["1111", "file"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(
|
||||
["1111", "file"],
|
||||
FileVariable(
|
||||
name="file",
|
||||
value=File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1111",
|
||||
storage_key="",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda *args, **kwargs: b"test",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
assert kwargs["files"] == [("file", (None, b"test", "application/octet-stream"))]
|
||||
return httpx.Response(200, content=b"")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="HTTP request tests use old Graph constructor incompatible with new queue-based engine - "
|
||||
"needs rewrite for new architecture"
|
||||
)
|
||||
def test_http_request_node_form_with_multiple_files(monkeypatch):
|
||||
data = HttpRequestNodeData(
|
||||
title="test",
|
||||
method="post",
|
||||
url="http://example.org/upload",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="form-data",
|
||||
data=[
|
||||
BodyData(
|
||||
key="files",
|
||||
type="file",
|
||||
file=["1111", "files"],
|
||||
),
|
||||
BodyData(
|
||||
key="name",
|
||||
type="text",
|
||||
value="test",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
files = [
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file1",
|
||||
filename="image1.jpg",
|
||||
mime_type="image/jpeg",
|
||||
storage_key="",
|
||||
),
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file2",
|
||||
filename="document.pdf",
|
||||
mime_type="application/pdf",
|
||||
storage_key="",
|
||||
),
|
||||
]
|
||||
|
||||
variable_pool.add(
|
||||
["1111", "files"],
|
||||
ArrayFileVariable(
|
||||
name="files",
|
||||
value=files,
|
||||
),
|
||||
)
|
||||
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
|
||||
node = HttpRequestNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.http_request.executor.file_manager.download",
|
||||
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",
|
||||
)
|
||||
|
||||
def attr_checker(*args, **kwargs):
|
||||
assert kwargs["data"] == {"name": "test"}
|
||||
|
||||
assert len(kwargs["files"]) == 2
|
||||
assert kwargs["files"][0][0] == "files"
|
||||
assert kwargs["files"][1][0] == "files"
|
||||
|
||||
file_tuples = [f[1] for f in kwargs["files"]]
|
||||
file_contents = [f[1] for f in file_tuples]
|
||||
file_types = [f[2] for f in file_tuples]
|
||||
|
||||
assert b"test_image_data" in file_contents
|
||||
assert b"test_pdf_data" in file_contents
|
||||
assert "image/jpeg" in file_types
|
||||
assert "application/pdf" in file_types
|
||||
|
||||
return httpx.Response(200, content=b'{"status":"success"}')
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.helper.ssrf_proxy.post",
|
||||
attr_checker,
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == '{"status":"success"}'
|
||||
print(result.outputs["body"])
|
||||
|
|
@ -1,909 +0,0 @@
|
|||
import time
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables.segments import ArrayAnySegment, ArrayStringSegment
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
|
||||
)
|
||||
def test_run():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
# print(type(item), item)
|
||||
count += 1
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 20
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
|
||||
)
|
||||
def test_run_parallel():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-2-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt-2",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "tt-2-source-if-else-target",
|
||||
"source": "tt-2",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 321",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt-2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
count += 1
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
|
||||
assert count == 32
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
|
||||
)
|
||||
def test_iteration_run_in_parallel_mode():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-2-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt-2",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "tt-2-source-if-else-target",
|
||||
"source": "tt-2",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 321",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt-2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
parallel_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
parallel_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=parallel_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
parallel_iteration_node.init_node_data(parallel_node_config["data"])
|
||||
sequential_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
sequential_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=sequential_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
sequential_iteration_node.init_node_data(sequential_node_config["data"])
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
parallel_result = parallel_iteration_node._run()
|
||||
sequential_result = sequential_iteration_node._run()
|
||||
assert parallel_iteration_node._node_data.parallel_nums == 10
|
||||
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
|
||||
count = 0
|
||||
parallel_arr = []
|
||||
sequential_arr = []
|
||||
for item in parallel_result:
|
||||
count += 1
|
||||
parallel_arr.append(item)
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 32
|
||||
|
||||
for item in sequential_result:
|
||||
sequential_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayStringSegment(value=["dify 123", "dify 123"])}
|
||||
assert count == 64
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Iteration nodes are part of Phase 3 container support - not yet implemented in new queue-based engine"
|
||||
)
|
||||
def test_iteration_run_error_handle():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "iteration-start",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "tt2",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt2", "output"],
|
||||
"output_type": "array[string]",
|
||||
"start_node_id": "if-else",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1.split(arg2) }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
{"value_selector": ["iteration-1", "index"], "variable": "arg2"},
|
||||
],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
],
|
||||
},
|
||||
"id": "tt2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "1",
|
||||
"variable_selector": ["iteration-1", "item"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="1",
|
||||
files=[],
|
||||
query="dify",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
pool.add(["pe", "list_output"], ["1", "1"])
|
||||
error_node_config = {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=error_node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
iteration_node.init_node_data(error_node_config["data"])
|
||||
# execute continue on error node
|
||||
result = iteration_node._run()
|
||||
result_arr = []
|
||||
count = 0
|
||||
for item in result:
|
||||
result_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[None, None])}
|
||||
|
||||
assert count == 14
|
||||
# execute remove abnormal output
|
||||
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
|
||||
result = iteration_node._run()
|
||||
count = 0
|
||||
for item in result:
|
||||
count += 1
|
||||
if isinstance(item, StreamCompletedEvent):
|
||||
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.node_run_result.outputs == {"output": ArrayAnySegment(value=[])}
|
||||
assert count == 14
|
||||
|
|
@ -1,624 +0,0 @@
|
|||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPartialSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class ContinueOnErrorTestHelper:
|
||||
@staticmethod
|
||||
def get_code_node(
|
||||
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
|
||||
):
|
||||
"""Helper method to create a code node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"outputs": {"result": {"type": "number"}},
|
||||
"error_strategy": error_strategy,
|
||||
"title": "code",
|
||||
"variables": [],
|
||||
"code_language": "python3",
|
||||
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
||||
"type": "code",
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_http_node(
|
||||
error_strategy: str = "fail-branch",
|
||||
default_value: dict | None = None,
|
||||
authorization_success: bool = False,
|
||||
retry_config: dict = {},
|
||||
):
|
||||
"""Helper method to create a http node configuration"""
|
||||
authorization = (
|
||||
{
|
||||
"type": "api-key",
|
||||
"config": {
|
||||
"type": "basic",
|
||||
"api_key": "ak-xxx",
|
||||
"header": "api-key",
|
||||
},
|
||||
}
|
||||
if authorization_success
|
||||
else {
|
||||
"type": "api-key",
|
||||
# missing config field
|
||||
}
|
||||
)
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "http://example.com",
|
||||
"authorization": authorization,
|
||||
"headers": "X-Header:123",
|
||||
"params": "A:b",
|
||||
"body": None,
|
||||
"type": "http-request",
|
||||
"error_strategy": error_strategy,
|
||||
**retry_config,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a http node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
"title": "HTTP Request",
|
||||
"desc": "",
|
||||
"variables": [],
|
||||
"method": "get",
|
||||
"url": "https://api.github.com/issues",
|
||||
"authorization": {"type": "no-auth", "config": None},
|
||||
"headers": "",
|
||||
"params": "",
|
||||
"body": {"type": "none", "data": []},
|
||||
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a tool node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "a",
|
||||
"desc": "a",
|
||||
"provider_id": "maths",
|
||||
"provider_type": "builtin",
|
||||
"provider_name": "maths",
|
||||
"tool_name": "eval_expression",
|
||||
"tool_label": "eval_expression",
|
||||
"tool_configurations": {},
|
||||
"tool_parameters": {
|
||||
"expression": {
|
||||
"type": "variable",
|
||||
"value": ["1", "123", "args1"],
|
||||
}
|
||||
},
|
||||
"type": "tool",
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node.node_data.default_value = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||
"""Helper method to create a llm node configuration"""
|
||||
node = {
|
||||
"id": "node",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
"error_strategy": error_strategy,
|
||||
},
|
||||
}
|
||||
if default_value:
|
||||
node["data"]["default_value"] = default_value
|
||||
return node
|
||||
|
||||
@staticmethod
|
||||
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
|
||||
"""Helper method to create a graph engine instance for testing"""
|
||||
# Create graph initialization parameters
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="aaa",
|
||||
files=[],
|
||||
query="clear",
|
||||
conversation_id="abababa",
|
||||
),
|
||||
user_inputs=user_inputs or {"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
node_factory = DifyNodeFactory(init_params, graph_runtime_state)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
return GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
FAIL_BRANCH_EDGES = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-true-success-target",
|
||||
"source": "node",
|
||||
"target": "success",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-false-error-target",
|
||||
"source": "node",
|
||||
"target": "error",
|
||||
"sourceHandle": "fail-branch",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_code_default_value_continue_on_error():
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_code_node(
|
||||
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_code_fail_branch_continue_on_error():
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_code_node(error_code),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_http_node_default_value_continue_on_error():
|
||||
"""Test HTTP node with default value error strategy"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_http_node(
|
||||
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
|
||||
for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_http_node_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
# def test_tool_node_default_value_continue_on_error():
|
||||
# """Test tool node with default value error strategy"""
|
||||
# graph_config = {
|
||||
# "edges": DEFAULT_VALUE_EDGE,
|
||||
# "nodes": [
|
||||
# {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
# {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
# ContinueOnErrorTestHelper.get_tool_node(
|
||||
# "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
|
||||
# ),
|
||||
# ],
|
||||
# }
|
||||
|
||||
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
# events = list(graph_engine.run())
|
||||
|
||||
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
# assert any(
|
||||
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events # noqa: E501
|
||||
# )
|
||||
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
# def test_tool_node_fail_branch_continue_on_error():
|
||||
# """Test HTTP node with fail-branch error strategy"""
|
||||
# graph_config = {
|
||||
# "edges": FAIL_BRANCH_EDGES,
|
||||
# "nodes": [
|
||||
# {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
# {
|
||||
# "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
|
||||
# "id": "success",
|
||||
# },
|
||||
# {
|
||||
# "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
|
||||
# "id": "error",
|
||||
# },
|
||||
# ContinueOnErrorTestHelper.get_tool_node(),
|
||||
# ],
|
||||
# }
|
||||
|
||||
# graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
# events = list(graph_engine.run())
|
||||
|
||||
# assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
# assert any(
|
||||
# isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events # noqa: E501
|
||||
# )
|
||||
# assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_llm_node_default_value_continue_on_error():
|
||||
"""Test LLM node with default value error strategy"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_llm_node(
|
||||
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_llm_node_fail_branch_continue_on_error():
|
||||
"""Test LLM node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_llm_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_status_code_error_http_node_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(
|
||||
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
|
||||
)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_variable_pool_error_type_variable():
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
list(graph_engine.run())
|
||||
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
|
||||
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
|
||||
assert error_message != None
|
||||
assert error_type.value == "HTTPResponseCodeError"
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_no_node_in_fail_branch_continue_on_error():
|
||||
"""Test HTTP node with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES[:-1],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
|
||||
ContinueOnErrorTestHelper.get_http_node(),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
|
||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Continue-on-error functionality is part of Phase 2 enhanced error handling - "
|
||||
"not fully implemented in MVP of queue-based engine"
|
||||
)
|
||||
def test_stream_output_with_fail_branch_continue_on_error():
|
||||
"""Test stream output with fail-branch error strategy"""
|
||||
graph_config = {
|
||||
"edges": FAIL_BRANCH_EDGES,
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||
"id": "success",
|
||||
},
|
||||
{
|
||||
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
|
||||
"id": "error",
|
||||
},
|
||||
ContinueOnErrorTestHelper.get_llm_node(),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
|
||||
def llm_generator(self):
|
||||
contents = ["hi", "bye", "good morning"]
|
||||
|
||||
yield NodeRunStreamChunkEvent(
|
||||
node_id=self.node_id,
|
||||
node_type=self._node_type,
|
||||
selector=[self.node_id, "text"],
|
||||
chunk=contents[0],
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 1,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
with patch.object(LLMNode, "_run", new=llm_generator):
|
||||
events = list(graph_engine.run())
|
||||
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
|
||||
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models import UserFrom
|
||||
|
||||
|
||||
def _create_tool_node():
|
||||
data = ToolNodeData(
|
||||
title="Test Tool",
|
||||
tool_parameters={},
|
||||
provider_id="test_tool",
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_name="test tool",
|
||||
tool_name="test tool",
|
||||
tool_label="test tool",
|
||||
tool_configurations={},
|
||||
plugin_unique_identifier=None,
|
||||
desc="Exception handling test tool",
|
||||
error_strategy=ErrorStrategy.FAIL_BRANCH,
|
||||
version="1",
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
}
|
||||
node = ToolNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
class MockToolRuntime:
|
||||
def get_merged_runtime_parameters(self):
|
||||
pass
|
||||
|
||||
|
||||
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from []
|
||||
raise ToolInvokeError("oops")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Tool node test uses old Graph constructor incompatible with new queue-based engine - "
|
||||
"needs rewrite for new architecture"
|
||||
)
|
||||
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that ToolNode can handle ToolInvokeError when transforming
|
||||
messages generated by ToolEngine.generic_invoke.
|
||||
"""
|
||||
tool_node = _create_tool_node()
|
||||
|
||||
# Need to patch ToolManager and ToolEngine so that we don't
|
||||
# have to set up a database.
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_engine.ToolEngine.generic_invoke",
|
||||
lambda *args, **kwargs: mock_message_stream(),
|
||||
)
|
||||
|
||||
streams = list(tool_node._run())
|
||||
assert len(streams) == 1
|
||||
stream = streams[0]
|
||||
assert isinstance(stream, StreamCompletedEvent)
|
||||
result = stream.node_run_result
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "oops" in result.error
|
||||
assert "Failed to invoke tool" in result.error
|
||||
assert result.error_type == "ToolInvokeError"
|
||||
|
|
@ -14,6 +14,22 @@ interface HeaderParams {
|
|||
interface User {
|
||||
}
|
||||
|
||||
interface DifyFileBase {
|
||||
type: "image"
|
||||
}
|
||||
|
||||
export interface DifyRemoteFile extends DifyFileBase {
|
||||
transfer_method: "remote_url"
|
||||
url: string
|
||||
}
|
||||
|
||||
export interface DifyLocalFile extends DifyFileBase {
|
||||
transfer_method: "local_file"
|
||||
upload_file_id: string
|
||||
}
|
||||
|
||||
export type DifyFile = DifyRemoteFile | DifyLocalFile;
|
||||
|
||||
export declare class DifyClient {
|
||||
constructor(apiKey: string, baseUrl?: string);
|
||||
|
||||
|
|
@ -44,7 +60,7 @@ export declare class CompletionClient extends DifyClient {
|
|||
inputs: any,
|
||||
user: User,
|
||||
stream?: boolean,
|
||||
files?: File[] | null
|
||||
files?: DifyFile[] | null
|
||||
): Promise<any>;
|
||||
}
|
||||
|
||||
|
|
@ -55,7 +71,7 @@ export declare class ChatClient extends DifyClient {
|
|||
user: User,
|
||||
stream?: boolean,
|
||||
conversation_id?: string | null,
|
||||
files?: File[] | null
|
||||
files?: DifyFile[] | null
|
||||
): Promise<any>;
|
||||
|
||||
getSuggested(message_id: string, user: User): Promise<any>;
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ export const checkOrSetAccessToken = async (appCode?: string | null) => {
|
|||
[userId || 'DEFAULT']: res.access_token,
|
||||
}
|
||||
localStorage.setItem('token', JSON.stringify(accessTokenJson))
|
||||
localStorage.removeItem(CONVERSATION_ID_INFO)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import type { FC, PropsWithChildren } from 'react'
|
|||
import { useEffect } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { create } from 'zustand'
|
||||
import { useGlobalPublicStore } from './global-public-context'
|
||||
|
||||
type WebAppStore = {
|
||||
shareCode: string | null
|
||||
|
|
@ -56,6 +57,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
|
|||
}
|
||||
|
||||
const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
|
||||
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
|
||||
const updateShareCode = useWebAppStore(state => state.updateShareCode)
|
||||
const pathname = usePathname()
|
||||
|
|
@ -69,7 +71,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
|||
}, [shareCode, updateShareCode])
|
||||
|
||||
const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
|
||||
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false)
|
||||
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true)
|
||||
|
||||
useEffect(() => {
|
||||
if (accessModeResult?.accessMode) {
|
||||
|
|
@ -86,7 +88,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
|||
}
|
||||
}, [accessModeResult, updateWebAppAccessMode, shareCode])
|
||||
|
||||
if (isFetching || isFetchingAccessToken) {
|
||||
if (isGlobalPending || isFetching || isFetchingAccessToken) {
|
||||
return <div className='flex h-full w-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -430,9 +430,7 @@ export const ssePost = async (
|
|||
.then((res) => {
|
||||
if (!/^[23]\d{2}$/.test(String(res.status))) {
|
||||
if (res.status === 401) {
|
||||
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
|
||||
ssePost(url, fetchOptions, otherOptions)
|
||||
}).catch(() => {
|
||||
if (isPublicAPI) {
|
||||
res.json().then((data: any) => {
|
||||
if (isPublicAPI) {
|
||||
if (data.code === 'web_app_access_denied')
|
||||
|
|
@ -449,7 +447,14 @@ export const ssePost = async (
|
|||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
else {
|
||||
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
|
||||
ssePost(url, fetchOptions, otherOptions)
|
||||
}).catch((err) => {
|
||||
console.error(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
else {
|
||||
res.json().then((data) => {
|
||||
|
|
|
|||
|
|
@ -1,20 +1,12 @@
|
|||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share'
|
||||
|
||||
const NAME_SPACE = 'webapp'
|
||||
|
||||
export const useGetWebAppAccessModeByCode = (code: string | null) => {
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
return useQuery({
|
||||
queryKey: [NAME_SPACE, 'appAccessMode', code],
|
||||
queryFn: () => {
|
||||
if (systemFeatures.webapp_auth.enabled === false) {
|
||||
return {
|
||||
accessMode: AccessMode.PUBLIC,
|
||||
}
|
||||
}
|
||||
if (!code || code.length === 0)
|
||||
return Promise.reject(new Error('App code is required to get access mode'))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue