From 24ebe2f5c6f31c5be430531ca1a02f83af35c51e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 28 Jan 2026 19:57:55 +0800 Subject: [PATCH] refactor(graph_engine): Add a Config class for graph engine. (#31663) Signed-off-by: -LAN- --- api/.importlinter | 1 - api/core/workflow/graph_engine/__init__.py | 3 +- api/core/workflow/graph_engine/config.py | 14 ++++++ .../workflow/graph_engine/graph_engine.py | 19 ++------ .../worker_management/worker_pool.py | 45 +++++++------------ .../nodes/iteration/iteration_node.py | 3 +- api/core/workflow/nodes/loop/loop_node.py | 3 +- api/core/workflow/workflow_entry.py | 8 +++- .../layers/test_layer_initialization.py | 3 +- .../graph_engine/test_command_system.py | 5 ++- ...ditional_streaming_vs_template_workflow.py | 4 +- .../graph_engine/test_graph_engine.py | 6 ++- .../workflow/graph_engine/test_mock_nodes.py | 6 ++- .../test_parallel_streaming_workflow.py | 3 +- .../workflow/graph_engine/test_stop_event.py | 13 +++++- .../graph_engine/test_table_runner.py | 12 ++--- .../graph_engine/test_tool_in_chatflow.py | 3 +- 17 files changed, 89 insertions(+), 62 deletions(-) create mode 100644 api/core/workflow/graph_engine/config.py diff --git a/api/.importlinter b/api/.importlinter index cc7ffc15c8..2b4a3a5bd6 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -105,7 +105,6 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability - core.workflow.graph_engine.worker_management.worker_pool -> configs core.workflow.nodes.agent.agent_node -> core.model_manager core.workflow.nodes.agent.agent_node -> core.provider_manager core.workflow.nodes.agent.agent_node -> core.tools.tool_manager diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index fe792c71ad..0e1c7dd60a 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -1,3 +1,4 @@ +from .config import GraphEngineConfig from .graph_engine import GraphEngine -__all__ = ["GraphEngine"] +__all__ = ["GraphEngine", "GraphEngineConfig"] diff --git a/api/core/workflow/graph_engine/config.py b/api/core/workflow/graph_engine/config.py new file mode 100644 index 0000000000..10dbbd7535 --- /dev/null +++ b/api/core/workflow/graph_engine/config.py @@ -0,0 +1,14 @@ +""" +GraphEngine configuration models. +""" + +from pydantic import BaseModel + + +class GraphEngineConfig(BaseModel): + """Configuration for GraphEngine worker pool scaling.""" + + min_workers: int = 1 + max_workers: int = 5 + scale_up_threshold: int = 3 + scale_down_idle_time: float = 5.0 diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dbb2727c98..0b359a2392 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -37,6 +37,7 @@ from .command_processing import ( PauseCommandHandler, UpdateVariablesCommandHandler, ) +from .config import GraphEngineConfig from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager @@ -70,10 +71,7 @@ class GraphEngine: graph: Graph, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, - min_workers: int | None = None, - max_workers: int | None = None, - scale_up_threshold: int | None = None, - scale_down_idle_time: float | None = None, + config: GraphEngineConfig, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" # stop event @@ -85,18 +83,12 @@ class GraphEngine: self._graph_runtime_state.stop_event = self._stop_event self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel + self._config = config # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) self._graph_execution.workflow_id = workflow_id - # === Worker Management Parameters === - # Parameters for dynamic worker pool scaling - self._min_workers = min_workers - self._max_workers = max_workers - self._scale_up_threshold = scale_up_threshold - self._scale_down_idle_time = scale_down_idle_time - # === Execution Queues === self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) @@ -167,10 +159,7 @@ class GraphEngine: graph=self._graph, layers=self._layers, execution_context=execution_context, - min_workers=self._min_workers, - max_workers=self._max_workers, - scale_up_threshold=self._scale_up_threshold, - scale_down_idle_time=self._scale_down_idle_time, + config=self._config, stop_event=self._stop_event, ) diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 9ce7d16e93..3bff566ac8 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -10,11 +10,11 @@ import queue import threading from typing import final -from configs import dify_config from core.workflow.context import IExecutionContext from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..config import GraphEngineConfig from ..layers.base import GraphEngineLayer from ..ready_queue import ReadyQueue from ..worker import Worker @@ -38,11 +38,8 @@ class WorkerPool: graph: Graph, layers: list[GraphEngineLayer], stop_event: threading.Event, + config: GraphEngineConfig, execution_context: IExecutionContext | None = None, - min_workers: int | None = None, - max_workers: int | None = None, - scale_up_threshold: int | None = None, - scale_down_idle_time: float | None = None, ) -> None: """ Initialize the simple worker pool. @@ -52,23 +49,15 @@ class WorkerPool: event_queue: Queue for worker events graph: The workflow graph layers: Graph engine layers for node execution hooks + config: GraphEngine worker pool configuration execution_context: Optional execution context for context preservation - min_workers: Minimum number of workers - max_workers: Maximum number of workers - scale_up_threshold: Queue depth to trigger scale up - scale_down_idle_time: Seconds before scaling down idle workers """ self._ready_queue = ready_queue self._event_queue = event_queue self._graph = graph self._execution_context = execution_context self._layers = layers - - # Scaling parameters with defaults - self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS - self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS - self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD - self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME + self._config = config # Worker management self._workers: list[Worker] = [] @@ -96,18 +85,18 @@ class WorkerPool: if initial_count is None: node_count = len(self._graph.nodes) if node_count < 10: - initial_count = self._min_workers + initial_count = self._config.min_workers elif node_count < 50: - initial_count = min(self._min_workers + 1, self._max_workers) + initial_count = min(self._config.min_workers + 1, self._config.max_workers) else: - initial_count = min(self._min_workers + 2, self._max_workers) + initial_count = min(self._config.min_workers + 2, self._config.max_workers) logger.debug( "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", initial_count, node_count, - self._min_workers, - self._max_workers, + self._config.min_workers, + self._config.max_workers, ) # Create initial workers @@ -176,7 +165,7 @@ class WorkerPool: Returns: True if scaled up, False otherwise """ - if queue_depth > self._scale_up_threshold and current_count < self._max_workers: + if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: old_count = current_count self._create_worker() @@ -185,7 +174,7 @@ class WorkerPool: old_count, len(self._workers), queue_depth, - self._scale_up_threshold, + self._config.scale_up_threshold, ) return True return False @@ -204,7 +193,7 @@ class WorkerPool: True if scaled down, False otherwise """ # Skip if we're at minimum or have no idle workers - if current_count <= self._min_workers or idle_count == 0: + if current_count <= self._config.min_workers or idle_count == 0: return False # Check if we have excess capacity @@ -222,10 +211,10 @@ class WorkerPool: for worker in self._workers: # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time: + if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: # Don't remove if it would leave us unable to handle the queue remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2): + if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): workers_to_remove.append((worker, worker.worker_id)) # Only remove one worker per check to avoid aggressive scaling break @@ -242,7 +231,7 @@ class WorkerPool: old_count, len(self._workers), len(workers_to_remove), - self._scale_down_idle_time, + self._config.scale_down_idle_time, queue_depth, active_count, idle_count - len(workers_to_remove), @@ -286,6 +275,6 @@ class WorkerPool: return { "total_workers": len(self._workers), "queue_depth": self._ready_queue.qsize(), - "min_workers": self._min_workers, - "max_workers": self._max_workers, + "min_workers": self._config.min_workers, + "max_workers": self._config.max_workers, } diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index ced996e7e0..c19182549f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -591,7 +591,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine + from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.runtime import GraphRuntimeState @@ -640,6 +640,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph=iteration_graph, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), ) return graph_engine diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 07d05966cc..84a9c29414 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -416,7 +416,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine + from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.runtime import GraphRuntimeState @@ -452,6 +452,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): graph=loop_graph, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), ) return graph_engine diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b645f29d27..43f15f6fd0 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -14,7 +14,7 @@ from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer from core.workflow.graph_engine.protocols.command_channel import CommandChannel @@ -81,6 +81,12 @@ class WorkflowEntry: graph=graph, graph_runtime_state=graph_runtime_state, command_channel=command_channel, + config=GraphEngineConfig( + min_workers=dify_config.GRAPH_ENGINE_MIN_WORKERS, + max_workers=dify_config.GRAPH_ENGINE_MAX_WORKERS, + scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD, + scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, + ), ) # Add debug logging layer when in debug mode diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py index d6ba61c50c..f1086c9936 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -2,7 +2,7 @@ from __future__ import annotations import pytest -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.layers.base import ( GraphEngineLayer, @@ -43,6 +43,7 @@ def test_layer_runtime_state_available_after_engine_layer() -> None: graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) layer = LayerForTest() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index d826f7a900..1af5a80a56 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -8,7 +8,7 @@ from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.entities.commands import ( AbortCommand, @@ -67,6 +67,7 @@ def test_abort_command(): graph=mock_graph, graph_runtime_state=shared_runtime_state, # Use shared instance command_channel=command_channel, + config=GraphEngineConfig(), ) # Send abort command before starting @@ -173,6 +174,7 @@ def test_pause_command(): graph=mock_graph, graph_runtime_state=shared_runtime_state, command_channel=command_channel, + config=GraphEngineConfig(), ) pause_command = PauseCommand(reason="User requested pause") @@ -228,6 +230,7 @@ def test_update_variables_command_updates_pool(): graph=mock_graph, graph_runtime_state=shared_runtime_state, command_channel=command_channel, + config=GraphEngineConfig(), ) update_command = UpdateVariablesCommand( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index 70a772fc4c..ee944c8e3e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -7,7 +7,7 @@ This test validates that: """ from core.workflow.enums import NodeType -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunSucceededEvent, @@ -44,6 +44,7 @@ def test_streaming_output_with_blocking_equals_one(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Execute the workflow @@ -139,6 +140,7 @@ def test_streaming_output_with_blocking_not_equals_one(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Execute the workflow diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 02f20413e0..5a55d7086e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -11,7 +11,7 @@ from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st from core.workflow.enums import ErrorStrategy -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunPartialSucceededEvent, @@ -469,6 +469,7 @@ def test_layer_system_basic(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Add debug logging layer @@ -525,6 +526,7 @@ def test_layer_chaining(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Chain multiple layers @@ -572,6 +574,7 @@ def test_layer_error_handling(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Add faulty layer @@ -753,6 +756,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) events = list(engine.run()) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 5937bbfb39..2179ff663b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -566,7 +566,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): # Import dependencies from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine + from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.runtime import GraphRuntimeState @@ -623,6 +623,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): graph=iteration_graph, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), ) return graph_engine @@ -641,7 +642,7 @@ class MockLoopNode(MockNodeMixin, LoopNode): # Import dependencies from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine + from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.runtime import GraphRuntimeState @@ -685,6 +686,7 @@ class MockLoopNode(MockNodeMixin, LoopNode): graph=loop_graph, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs + config=GraphEngineConfig(), ) return graph_engine diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index e8cd665107..53c6bc3d60 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -17,7 +17,7 @@ from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunSucceededEvent, @@ -123,6 +123,7 @@ def test_parallel_streaming_workflow(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Define LLM outputs diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py index ea8d3a977f..0b998034b1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, Mock, patch from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunStartedEvent, @@ -41,6 +41,7 @@ class TestStopEventPropagation: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Verify stop_event was created @@ -84,6 +85,7 @@ class TestStopEventPropagation: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Set the stop_event before running @@ -131,6 +133,7 @@ class TestStopEventPropagation: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Initially not set @@ -155,6 +158,7 @@ class TestStopEventPropagation: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Verify WorkerPool has the stop_event @@ -174,6 +178,7 @@ class TestStopEventPropagation: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Verify Dispatcher has the stop_event @@ -311,6 +316,7 @@ class TestStopEventIntegration: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Set stop_event before running @@ -360,6 +366,7 @@ class TestStopEventIntegration: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # All nodes should share the same stop_event @@ -385,6 +392,7 @@ class TestStopEventTimeoutBehavior: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) dispatcher = engine._dispatcher @@ -411,6 +419,7 @@ class TestStopEventTimeoutBehavior: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) worker_pool = engine._worker_pool @@ -460,6 +469,7 @@ class TestStopEventResumeBehavior: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Simulate a previous execution that set stop_event @@ -490,6 +500,7 @@ class TestWorkerStopBehavior: graph=mock_graph, graph_runtime_state=runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) # Get the worker pool and check workers diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 10ac1206fb..afa9265fcd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -32,7 +32,7 @@ from core.variables import ( ) from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphEngineEvent, @@ -309,10 +309,12 @@ class TableTestRunner: graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), - min_workers=self.graph_engine_min_workers, - max_workers=self.graph_engine_max_workers, - scale_up_threshold=self.graph_engine_scale_up_threshold, - scale_down_idle_time=self.graph_engine_scale_down_idle_time, + config=GraphEngineConfig( + min_workers=self.graph_engine_min_workers, + max_workers=self.graph_engine_max_workers, + scale_up_threshold=self.graph_engine_scale_up_threshold, + scale_down_idle_time=self.graph_engine_scale_down_idle_time, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 34682ff8f9..bfcc6e1a5f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,4 +1,4 @@ -from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine import GraphEngine, GraphEngineConfig from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunSucceededEvent, @@ -27,6 +27,7 @@ def test_tool_in_chatflow(): graph=graph, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), + config=GraphEngineConfig(), ) events = list(engine.run())