From 6ffa2ebabf353c89c17abe2c0f92921046d728aa Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 9 Sep 2025 22:16:42 +0800 Subject: [PATCH 01/31] feat: improve error handling in graph node creation - Replace ValueError catch with generic Exception - Use logger.exception for automatic traceback logging - Abort on node creation failure instead of continuing --- api/core/workflow/graph/graph.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index dc38d4d2a3..8654ad4ef5 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -188,9 +188,9 @@ class Graph: for node_id, node_config in node_configs_map.items(): try: node_instance = node_factory.create_node(node_config) - except ValueError as e: - logger.warning("Failed to create node instance: %s", str(e)) - continue + except Exception: + logger.exception("Failed to create node instance for node_id %s", node_id) + raise nodes[node_id] = node_instance return nodes From e0e82fbfaa8116f188312171196a703ade4aa96c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:15:36 +0800 Subject: [PATCH 02/31] refactor: extract _run method into smaller focused methods in IterationNode - Extract iterator variable retrieval and validation logic - Separate empty iteration handling - Create dedicated methods for iteration execution and result handling - Improve type hints and use modern Python syntax - Enhance code readability and maintainability --- .../nodes/iteration/iteration_node.py | 229 ++++++++++++------ 1 file changed, 149 insertions(+), 80 deletions(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index f15730d105..6aaa432ca7 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,7 +1,9 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, NewType, cast + +from typing_extensions import TypeIs from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment @@ -23,6 +25,7 @@ from core.workflow.node_events import ( IterationNextEvent, IterationStartedEvent, IterationSucceededEvent, + NodeEventBase, NodeRunResult, StreamCompletedEvent, ) @@ -45,6 +48,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) class IterationNode(Node): """ @@ -58,7 +62,7 @@ class IterationNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -67,7 +71,7 @@ class IterationNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -77,7 +81,7 @@ class IterationNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict[str, object] | None = None): return { "type": "iteration", "config": { @@ 
-91,40 +95,17 @@ class IterationNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: - variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) + def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride] + variable = self._get_iterator_variable() - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - if isinstance(variable, NoneSegment) or len(variable.value) == 0: - # Try our best to preserve the type informat. - if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) + if self._is_empty_iteration(variable): + yield from self._handle_empty_iteration(variable) return - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - + iterator_list_value = self._validate_and_get_iterator_list(variable) inputs = {"iterator_selector": iterator_list_value} - if not self._node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") + self._validate_start_node() started_at = naive_utc_now() iter_run_map: dict[str, float] = {} @@ -137,62 +118,150 @@ class IterationNode(Node): ) try: - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Update the total tokens from this iteration - self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, + yield from self._execute_iterations( + iterator_list_value=iterator_list_value, + outputs=outputs, + iter_run_map=iter_run_map, ) - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - }, - ) + yield from self._handle_iteration_success( + started_at=started_at, + inputs=inputs, + outputs=outputs, + iterator_list_value=iterator_list_value, + iter_run_map=iter_run_map, ) except IterationNodeError as e: - yield IterationFailedEvent( - start_at=started_at, + yield from 
self._handle_iteration_failure( + started_at=started_at, inputs=inputs, + outputs=outputs, + iterator_list_value=iterator_list_value, + iter_run_map=iter_run_map, + error=e, + ) + + def _get_iterator_variable(self) -> ArraySegment | NoneSegment: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) + + if not variable: + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") + + if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): + raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") + + return variable + + def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: + return isinstance(variable, NoneSegment) or len(variable.value) == 0 + + def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: + # Try our best to preserve the type information. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, + ) + ) + + def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: + iterator_list_value = variable.to_object() + + if not isinstance(iterator_list_value, list): + raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + + return cast(list[object], iterator_list_value) + + def _validate_start_node(self) -> None: + if not self._node_data.start_node_id: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") + + def _execute_iterations( + self, + iterator_list_value: Sequence[object], + outputs: list[Any], + iter_run_map: dict[str, float], + ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: + for index, item in enumerate(iterator_list_value): + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + yield IterationNextEvent(index=index) + + graph_engine = self._create_graph_engine(index, item) + + # Run the iteration + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, + ) + + # Update the total tokens from this iteration + self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + + def _handle_iteration_success( + self, + started_at: datetime, + inputs: dict[str, Sequence[object]], + outputs: list[Any], + iterator_list_value: Sequence[object], + iter_run_map: dict[str, float], + ) -> Generator[NodeEventBase, None, None]: + yield IterationSucceededEvent( + start_at=started_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + }, + ) + + # Yield final success event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": outputs}, - steps=len(iterator_list_value), metadata={ 
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, - error=str(e), ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + ) + + def _handle_iteration_failure( + self, + started_at: datetime, + inputs: dict[str, Sequence[object]], + outputs: list[Any], + iterator_list_value: Sequence[object], + iter_run_map: dict[str, float], + error: IterationNodeError, + ) -> Generator[NodeEventBase, None, None]: + yield IterationFailedEvent( + start_at=started_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + }, + error=str(error), + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(error), ) + ) @classmethod def _extract_variable_selector_to_variable_mapping( @@ -305,9 +374,9 @@ class IterationNode(Node): self, *, variable_pool: VariablePool, - outputs: list, + outputs: list[object], graph_engine: "GraphEngine", - ) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]: + ) -> Generator[GraphNodeEventBase, None, None]: rst = graph_engine.run() # get current iteration index index_variable = variable_pool.get([self._node_id, "index"]) From a23c8fcb1ae6eb11d2960d3d3eb7f250150a9213 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:32:45 +0800 Subject: [PATCH 03/31] refactor: move execution limits from engine core to layer Remove max_execution_time and max_execution_steps from ExecutionContext and GraphEngine since these limits are now handled by ExecutionLimitsLayer. This follows the separation of concerns principle by keeping execution limits as a cross-cutting concern handled by layers rather than embedded in core engine components. 
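As an illustration only, a minimal sketch of the intended wiring, assuming
ExecutionLimitsLayer accepts the limits in its constructor and that the engine
exposes a layer() registration hook (the constructor arguments and the layer()
call below are assumptions for illustration, not confirmed signatures):

    # Hypothetical sketch: limits supplied by a layer rather than engine params.
    engine = GraphEngine(
        graph=graph,
        graph_config=graph_config,
        graph_runtime_state=graph_runtime_state,
        command_channel=InMemoryChannel(),
        # ...plus any other required engine arguments...
    )
    engine.layer(  # assumed registration hook
        ExecutionLimitsLayer(
            max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
        )
    )
    for event in engine.run():
        ...
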
Changes: - Remove max_execution_time and max_execution_steps from ExecutionContext - Remove these parameters from GraphEngine.__init__() - Remove max_execution_time from Dispatcher - Update workflow_entry.py to no longer pass these parameters - Update all tests to remove these parameters --- api/core/workflow/graph_engine/domain/execution_context.py | 6 ------ api/core/workflow/graph_engine/graph_engine.py | 5 ----- api/core/workflow/graph_engine/orchestration/dispatcher.py | 3 --- api/core/workflow/nodes/iteration/iteration_node.py | 5 ++--- api/core/workflow/nodes/loop/loop_node.py | 3 --- api/core/workflow/workflow_entry.py | 2 -- .../core/workflow/graph_engine/test_command_system.py | 2 -- .../test_conditional_streaming_vs_template_workflow.py | 4 ---- .../core/workflow/graph_engine/test_graph_engine.py | 6 ------ .../core/workflow/graph_engine/test_mock_nodes.py | 4 ---- .../graph_engine/test_parallel_streaming_workflow.py | 2 -- .../core/workflow/graph_engine/test_table_runner.py | 2 -- .../core/workflow/graph_engine/test_tool_in_chatflow.py | 2 -- 13 files changed, 2 insertions(+), 44 deletions(-) diff --git a/api/core/workflow/graph_engine/domain/execution_context.py b/api/core/workflow/graph_engine/domain/execution_context.py index 0b4116f39d..9bcff0fea7 100644 --- a/api/core/workflow/graph_engine/domain/execution_context.py +++ b/api/core/workflow/graph_engine/domain/execution_context.py @@ -24,14 +24,8 @@ class ExecutionContext: user_from: UserFrom invoke_from: InvokeFrom call_depth: int - max_execution_steps: int - max_execution_time: int def __post_init__(self) -> None: """Validate execution context parameters.""" if self.call_depth < 0: raise ValueError("Call depth must be non-negative") - if self.max_execution_steps <= 0: - raise ValueError("Max execution steps must be positive") - if self.max_execution_time <= 0: - raise ValueError("Max execution time must be positive") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 45f3ada7f5..b6563058bc 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -65,8 +65,6 @@ class GraphEngine: graph: Graph, graph_config: Mapping[str, object], graph_runtime_state: GraphRuntimeState, - max_execution_steps: int, - max_execution_time: int, command_channel: CommandChannel, min_workers: int | None = None, max_workers: int | None = None, @@ -85,8 +83,6 @@ class GraphEngine: user_from=user_from, invoke_from=invoke_from, call_depth=call_depth, - max_execution_steps=max_execution_steps, - max_execution_time=max_execution_time, ) # Graph execution tracks the overall execution state @@ -216,7 +212,6 @@ class GraphEngine: event_handler=self._event_handler_registry, event_collector=self._event_manager, execution_coordinator=self._execution_coordinator, - max_execution_time=self._execution_context.max_execution_time, event_emitter=self._event_manager, ) diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 80f744c941..bb4720a684 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -34,7 +34,6 @@ class Dispatcher: event_handler: "EventHandler", event_collector: EventManager, execution_coordinator: ExecutionCoordinator, - max_execution_time: int, event_emitter: EventManager | None = None, ) -> None: """ @@ -45,14 +44,12 @@ class Dispatcher: event_handler: Event 
handler registry for processing events event_collector: Event manager for collecting unhandled events execution_coordinator: Coordinator for execution flow - max_execution_time: Maximum execution time in seconds event_emitter: Optional event manager to signal completion """ self._event_queue = event_queue self._event_handler = event_handler self._event_collector = event_collector self._execution_coordinator = execution_coordinator - self._max_execution_time = max_execution_time self._event_emitter = event_emitter self._thread: threading.Thread | None = None diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 6aaa432ca7..56b0421454 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -50,6 +50,7 @@ logger = logging.getLogger(__name__) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) + class IterationNode(Node): """ Iteration Node. @@ -95,7 +96,7 @@ class IterationNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride] + def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride] variable = self._get_iterator_variable() if self._is_empty_iteration(variable): @@ -466,8 +467,6 @@ class IterationNode(Node): graph=iteration_graph, graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, - max_execution_steps=10000, # Use default or config value - max_execution_time=600, # Use default or config value command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 3c5259ea26..ba26322cc3 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -4,7 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast -from configs import dify_config from core.variables import Segment, SegmentType from core.workflow.enums import ( ErrorStrategy, @@ -454,8 +453,6 @@ class LoopNode(Node): graph=loop_graph, graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 544e4cb5b4..466e537a1a 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -83,8 +83,6 @@ class WorkflowEntry: graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, command_channel=command_channel, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 40b164a0c2..58073ba5c3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -52,8 +52,6 @@ def test_abort_command(): graph=mock_graph, graph_config={}, 
graph_runtime_state=shared_runtime_state, # Use shared instance - max_execution_steps=100, - max_execution_time=10, command_channel=command_channel, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index 7ea789af51..2b2e4fe022 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -55,8 +55,6 @@ def test_streaming_output_with_blocking_equals_one(): graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=30, command_channel=InMemoryChannel(), ) @@ -162,8 +160,6 @@ def test_streaming_output_with_blocking_not_equals_one(): graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=30, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 11eecb6d77..4aa33bde26 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -470,8 +470,6 @@ def test_layer_system_basic(): graph=graph, graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, - max_execution_steps=300, - max_execution_time=60, command_channel=InMemoryChannel(), ) @@ -535,8 +533,6 @@ def test_layer_chaining(): graph=graph, graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, - max_execution_steps=300, - max_execution_time=60, command_channel=InMemoryChannel(), ) @@ -591,8 +587,6 @@ def test_layer_error_handling(): graph=graph, graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, - max_execution_steps=300, - max_execution_time=60, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 3a8142d857..8229409ffd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -625,8 +625,6 @@ class MockIterationNode(MockNodeMixin, IterationNode): graph=iteration_graph, graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, - max_execution_steps=10000, # Use default or config value - max_execution_time=600, # Use default or config value command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) @@ -695,8 +693,6 @@ class MockLoopNode(MockNodeMixin, LoopNode): graph=loop_graph, graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, - max_execution_steps=10000, # Use default or config value - max_execution_time=600, # Use default or config value command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 581f9a07da..04f0aa7f2e 100644 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -128,8 +128,6 @@ def test_parallel_streaming_workflow(): graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=30, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 744e558e99..4c744e91bd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -388,8 +388,6 @@ class TableTestRunner: graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=int(test_case.timeout), command_channel=InMemoryChannel(), min_workers=self.graph_engine_min_workers, max_workers=self.graph_engine_max_workers, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index a192eadc82..e227518a8e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -38,8 +38,6 @@ def test_tool_in_chatflow(): graph=graph, graph_config=graph_config, graph_runtime_state=graph_runtime_state, - max_execution_steps=500, - max_execution_time=30, command_channel=InMemoryChannel(), ) From ea5dfe41d53873113e774a9ca8930eedb00924cd Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:36:11 +0800 Subject: [PATCH 04/31] chore: ignore comment Signed-off-by: -LAN- --- api/core/workflow/nodes/iteration/iteration_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 56b0421454..e092536d0a 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -96,7 +96,7 @@ class IterationNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride] + def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore variable = self._get_iterator_variable() if self._is_empty_iteration(variable): From e060d7c28cf51a0ba510387cfa65e3b49e43bb58 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:49:15 +0800 Subject: [PATCH 05/31] refactor(graph_engine): remove Optional Signed-off-by: -LAN- --- api/core/workflow/graph_events/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/graph_events/base.py b/api/core/workflow/graph_events/base.py index 98ffef7924..3714679201 100644 --- a/api/core/workflow/graph_events/base.py +++ b/api/core/workflow/graph_events/base.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from core.workflow.enums import NodeType @@ -19,9 +17,9 @@ class GraphNodeEventBase(GraphEngineEvent): node_id: str node_type: NodeType - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" 
# The version of the node, or "1" if not specified. From d52621fce37f0305fbbd4e4b54d89c4ba7e1c1a1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:49:46 +0800 Subject: [PATCH 06/31] refactor(graph_engine): Merge error strategies into error_handler.py Signed-off-by: -LAN- --- .../graph_engine/error_handling/__init__.py | 14 +- .../error_handling/abort_strategy.py | 40 ----- .../error_handling/default_value_strategy.py | 58 ------- .../error_handling/error_handler.py | 164 ++++++++++++++++-- .../error_handling/fail_branch_strategy.py | 57 ------ .../error_handling/retry_strategy.py | 52 ------ .../graph_engine/protocols/error_strategy.py | 31 ---- 7 files changed, 150 insertions(+), 266 deletions(-) delete mode 100644 api/core/workflow/graph_engine/error_handling/abort_strategy.py delete mode 100644 api/core/workflow/graph_engine/error_handling/default_value_strategy.py delete mode 100644 api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py delete mode 100644 api/core/workflow/graph_engine/error_handling/retry_strategy.py delete mode 100644 api/core/workflow/graph_engine/protocols/error_strategy.py diff --git a/api/core/workflow/graph_engine/error_handling/__init__.py b/api/core/workflow/graph_engine/error_handling/__init__.py index 1316710d0d..3189fea0c9 100644 --- a/api/core/workflow/graph_engine/error_handling/__init__.py +++ b/api/core/workflow/graph_engine/error_handling/__init__.py @@ -1,20 +1,12 @@ """ -Error handling strategies for graph engine. +Error handling for graph engine. -This package implements different error recovery strategies using -the Strategy pattern for clean separation of concerns. +This package provides error handling functionality for managing +node execution failures with different recovery strategies. """ -from .abort_strategy import AbortStrategy -from .default_value_strategy import DefaultValueStrategy from .error_handler import ErrorHandler -from .fail_branch_strategy import FailBranchStrategy -from .retry_strategy import RetryStrategy __all__ = [ - "AbortStrategy", - "DefaultValueStrategy", "ErrorHandler", - "FailBranchStrategy", - "RetryStrategy", ] diff --git a/api/core/workflow/graph_engine/error_handling/abort_strategy.py b/api/core/workflow/graph_engine/error_handling/abort_strategy.py deleted file mode 100644 index 4593f004f3..0000000000 --- a/api/core/workflow/graph_engine/error_handling/abort_strategy.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Abort error strategy implementation. -""" - -import logging -from typing import final - -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent - -logger = logging.getLogger(__name__) - - -@final -class AbortStrategy: - """ - Error strategy that aborts execution on failure. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - """ - - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: - """ - Handle error by aborting execution. 
- - Args: - event: The failure event - graph: The workflow graph - retry_count: Current retry attempt count (unused) - - Returns: - None - signals abortion - """ - _ = graph - _ = retry_count - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - - # Return None to signal that execution should stop - return None diff --git a/api/core/workflow/graph_engine/error_handling/default_value_strategy.py b/api/core/workflow/graph_engine/error_handling/default_value_strategy.py deleted file mode 100644 index 3cdcec88e5..0000000000 --- a/api/core/workflow/graph_engine/error_handling/default_value_strategy.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -Default value error strategy implementation. -""" - -from typing import final - -from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent -from core.workflow.node_events import NodeRunResult - - -@final -class DefaultValueStrategy: - """ - Error strategy that uses default values on failure. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. - """ - - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: - """ - Handle error by using default values. - - Args: - event: The failure event - graph: The workflow graph - retry_count: Current retry attempt count (unused) - - Returns: - NodeRunExceptionEvent with default values - """ - _ = retry_count - node = graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/core/workflow/graph_engine/error_handling/error_handler.py b/api/core/workflow/graph_engine/error_handling/error_handler.py index b51d7e4dad..aa6e6287e0 100644 --- a/api/core/workflow/graph_engine/error_handling/error_handler.py +++ b/api/core/workflow/graph_engine/error_handling/error_handler.py @@ -2,20 +2,31 @@ Main error handler that coordinates error strategies. 
""" +import logging +import time from typing import TYPE_CHECKING, final -from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum +from core.workflow.enums import ( + ErrorStrategy as ErrorStrategyEnum, +) +from core.workflow.enums import ( + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent - -from .abort_strategy import AbortStrategy -from .default_value_strategy import DefaultValueStrategy -from .fail_branch_strategy import FailBranchStrategy -from .retry_strategy import RetryStrategy +from core.workflow.graph_events import ( + GraphNodeEventBase, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunRetryEvent, +) +from core.workflow.node_events import NodeRunResult if TYPE_CHECKING: from ..domain import GraphExecution +logger = logging.getLogger(__name__) + @final class ErrorHandler: @@ -38,12 +49,6 @@ class ErrorHandler: self._graph = graph self._graph_execution = graph_execution - # Initialize strategies - self._abort_strategy = AbortStrategy() - self._retry_strategy = RetryStrategy() - self._fail_branch_strategy = FailBranchStrategy() - self._default_value_strategy = DefaultValueStrategy() - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: """ Handle a node failure event. @@ -64,7 +69,7 @@ class ErrorHandler: # First check if retry is configured and not exhausted if node.retry and retry_count < node.retry_config.max_retries: - result = self._retry_strategy.handle_error(event, self._graph, retry_count) + result = self._handle_retry(event, retry_count) if result: # Retry count will be incremented when NodeRunRetryEvent is handled return result @@ -74,8 +79,133 @@ class ErrorHandler: match strategy: case None: - return self._abort_strategy.handle_error(event, self._graph, retry_count) + return self._handle_abort(event) case ErrorStrategyEnum.FAIL_BRANCH: - return self._fail_branch_strategy.handle_error(event, self._graph, retry_count) + return self._handle_fail_branch(event) case ErrorStrategyEnum.DEFAULT_VALUE: - return self._default_value_strategy.handle_error(event, self._graph, retry_count) + return self._handle_default_value(event) + + def _handle_abort(self, event: NodeRunFailedEvent): + """ + Handle error by aborting execution. + + This is the default strategy when no other strategy is specified. + It stops the entire graph execution when a node fails. + + Args: + event: The failure event + + Returns: + None - signals abortion + """ + logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) + # Return None to signal that execution should stop + + def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): + """ + Handle error by retrying the node. + + This strategy re-attempts node execution up to a configured + maximum number of retries with configurable intervals. 
+ + Args: + event: The failure event + retry_count: Current retry attempt count + + Returns: + NodeRunRetryEvent if retry should occur, None otherwise + """ + node = self._graph.nodes[event.node_id] + + # Check if we've exceeded max retries + if not node.retry or retry_count >= node.retry_config.max_retries: + return None + + # Wait for retry interval + time.sleep(node.retry_config.retry_interval_seconds) + + # Create retry event + return NodeRunRetryEvent( + id=event.id, + node_title=node.title, + node_id=event.node_id, + node_type=event.node_type, + node_run_result=event.node_run_result, + start_at=event.start_at, + error=event.error, + retry_index=retry_count + 1, + ) + + def _handle_fail_branch(self, event: NodeRunFailedEvent): + """ + Handle error by taking the fail branch. + + This strategy converts failures to exceptions and routes execution + through a designated fail-branch edge. + + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent to continue via fail branch + """ + outputs = { + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + edge_source_handle="fail-branch", + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, + }, + ), + error=event.error, + ) + + def _handle_default_value(self, event: NodeRunFailedEvent): + """ + Handle error by using default values. + + This strategy allows nodes to fail gracefully by providing + predefined default output values. + + Args: + event: The failure event + + Returns: + NodeRunExceptionEvent with default values + """ + node = self._graph.nodes[event.node_id] + + outputs = { + **node.default_value_dict, + "error_message": event.node_run_result.error, + "error_type": event.node_run_result.error_type, + } + + return NodeRunExceptionEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + start_at=event.start_at, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + inputs=event.node_run_result.inputs, + process_data=event.node_run_result.process_data, + outputs=outputs, + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, + }, + ), + error=event.error, + ) diff --git a/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py b/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py deleted file mode 100644 index 1c156b5be1..0000000000 --- a/api/core/workflow/graph_engine/error_handling/fail_branch_strategy.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Fail branch error strategy implementation. -""" - -from typing import final - -from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent -from core.workflow.node_events import NodeRunResult - - -@final -class FailBranchStrategy: - """ - Error strategy that continues execution via a fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. 
- """ - - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: - """ - Handle error by taking the fail branch. - - Args: - event: The failure event - graph: The workflow graph - retry_count: Current retry attempt count (unused) - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - _ = graph - _ = retry_count - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.FAIL_BRANCH, - }, - ), - error=event.error, - ) diff --git a/api/core/workflow/graph_engine/error_handling/retry_strategy.py b/api/core/workflow/graph_engine/error_handling/retry_strategy.py deleted file mode 100644 index e4010b6bdb..0000000000 --- a/api/core/workflow/graph_engine/error_handling/retry_strategy.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Retry error strategy implementation. -""" - -import time -from typing import final - -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent - - -@final -class RetryStrategy: - """ - Error strategy that retries failed nodes. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. - """ - - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: - """ - Handle error by retrying the node. - - Args: - event: The failure event - graph: The workflow graph - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) diff --git a/api/core/workflow/graph_engine/protocols/error_strategy.py b/api/core/workflow/graph_engine/protocols/error_strategy.py deleted file mode 100644 index bf8b316423..0000000000 --- a/api/core/workflow/graph_engine/protocols/error_strategy.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Base error strategy protocol. -""" - -from typing import Protocol - -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent - - -class ErrorStrategy(Protocol): - """ - Protocol for error handling strategies. - - Each strategy implements a different approach to handling - node execution failures. - """ - - def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None: - """ - Handle a node failure event. 
- - Args: - event: The failure event - graph: The workflow graph - retry_count: Current retry attempt count - - Returns: - Optional new event to process, or None to stop - """ - ... From f17c71e08a07e5ea9ad6f735d26c887e4d114310 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:55:30 +0800 Subject: [PATCH 07/31] refactor(graph_engine): Move GraphStateManager to single file package. Signed-off-by: -LAN- --- .../event_management/event_handlers.py | 4 ++-- .../workflow/graph_engine/graph_engine.py | 4 ++-- ...tate_manager.py => graph_state_manager.py} | 23 +++---------------- .../graph_traversal/edge_processor.py | 4 ++-- .../graph_traversal/skip_propagator.py | 4 ++-- .../orchestration/execution_coordinator.py | 4 ++-- .../graph_engine/state_management/__init__.py | 12 ---------- 7 files changed, 13 insertions(+), 42 deletions(-) rename api/core/workflow/graph_engine/{state_management/unified_state_manager.py => graph_state_manager.py} (91%) delete mode 100644 api/core/workflow/graph_engine/state_management/__init__.py diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 3ab69776a4..10e7d421af 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -32,8 +32,8 @@ from ..response_coordinator import ResponseStreamCoordinator if TYPE_CHECKING: from ..error_handling import ErrorHandler + from ..graph_state_manager import GraphStateManager from ..graph_traversal import EdgeProcessor - from ..state_management import UnifiedStateManager from .event_manager import EventManager logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class EventHandler: response_coordinator: ResponseStreamCoordinator, event_collector: "EventManager", edge_processor: "EdgeProcessor", - state_manager: "UnifiedStateManager", + state_manager: "GraphStateManager", error_handler: "ErrorHandler", ) -> None: """ diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b6563058bc..019a1aaade 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -33,12 +33,12 @@ from .domain import ExecutionContext, GraphExecution from .entities.commands import AbortCommand from .error_handling import ErrorHandler from .event_management import EventHandler, EventManager +from .graph_state_manager import GraphStateManager from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel from .response_coordinator import ResponseStreamCoordinator -from .state_management import UnifiedStateManager from .worker_management import WorkerPool logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ class GraphEngine: # === State Management === # Unified state manager handles all node state transitions and queue operations - self._state_manager = UnifiedStateManager(self._graph, self._ready_queue) + self._state_manager = GraphStateManager(self._graph, self._ready_queue) # === Response Coordination === # Coordinates response streaming from response nodes diff --git a/api/core/workflow/graph_engine/state_management/unified_state_manager.py b/api/core/workflow/graph_engine/graph_state_manager.py similarity index 91% rename from 
api/core/workflow/graph_engine/state_management/unified_state_manager.py rename to api/core/workflow/graph_engine/graph_state_manager.py index 258b84c341..efc3992ac9 100644 --- a/api/core/workflow/graph_engine/state_management/unified_state_manager.py +++ b/api/core/workflow/graph_engine/graph_state_manager.py @@ -1,8 +1,5 @@ """ -Unified state manager that combines node, edge, and execution tracking. - -This is a proposed simplification that merges NodeStateManager, EdgeStateManager, -and ExecutionTracker into a single cohesive class. +Graph state manager that combines node, edge, and execution tracking. """ import queue @@ -23,24 +20,10 @@ class EdgeStateAnalysis(TypedDict): @final -class UnifiedStateManager: - """ - Unified manager for all graph state operations. - - This class combines the responsibilities of: - - NodeStateManager: Node state transitions and ready queue - - EdgeStateManager: Edge state transitions and analysis - - ExecutionTracker: Tracking executing nodes - - Benefits: - - Single lock for all state operations (reduced contention) - - Cohesive state management interface - - Simplified dependency injection - """ - +class GraphStateManager: def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None: """ - Initialize the unified state manager. + Initialize the state manager. Args: graph: The workflow graph diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py index c5634ed984..9bd0f86fbf 100644 --- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py +++ b/api/core/workflow/graph_engine/graph_traversal/edge_processor.py @@ -9,8 +9,8 @@ from core.workflow.enums import NodeExecutionType from core.workflow.graph import Edge, Graph from core.workflow.graph_events import NodeRunStreamChunkEvent +from ..graph_state_manager import GraphStateManager from ..response_coordinator import ResponseStreamCoordinator -from ..state_management import UnifiedStateManager if TYPE_CHECKING: from .skip_propagator import SkipPropagator @@ -29,7 +29,7 @@ class EdgeProcessor: def __init__( self, graph: Graph, - state_manager: UnifiedStateManager, + state_manager: GraphStateManager, response_coordinator: ResponseStreamCoordinator, skip_propagator: "SkipPropagator", ) -> None: diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py index 51ab3c6739..78f8ecdcdf 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py @@ -7,7 +7,7 @@ from typing import final from core.workflow.graph import Edge, Graph -from ..state_management import UnifiedStateManager +from ..graph_state_manager import GraphStateManager @final @@ -22,7 +22,7 @@ class SkipPropagator: def __init__( self, graph: Graph, - state_manager: UnifiedStateManager, + state_manager: GraphStateManager, ) -> None: """ Initialize the skip propagator. 
diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py index 234a3607c3..b35e8bb6d8 100644 --- a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py +++ b/api/core/workflow/graph_engine/orchestration/execution_coordinator.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, final from ..command_processing import CommandProcessor from ..domain import GraphExecution from ..event_management import EventManager -from ..state_management import UnifiedStateManager +from ..graph_state_manager import GraphStateManager from ..worker_management import WorkerPool if TYPE_CHECKING: @@ -26,7 +26,7 @@ class ExecutionCoordinator: def __init__( self, graph_execution: GraphExecution, - state_manager: UnifiedStateManager, + state_manager: GraphStateManager, event_handler: "EventHandler", event_collector: EventManager, command_processor: CommandProcessor, diff --git a/api/core/workflow/graph_engine/state_management/__init__.py b/api/core/workflow/graph_engine/state_management/__init__.py deleted file mode 100644 index 9a632a3b9f..0000000000 --- a/api/core/workflow/graph_engine/state_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -State management subsystem for graph engine. - -This package manages node states, edge states, and execution tracking -during workflow graph execution. -""" - -from .unified_state_manager import UnifiedStateManager - -__all__ = [ - "UnifiedStateManager", -] From 9cf2b2b231e737688f887301ab5bb0172c409e41 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 02:22:58 +0800 Subject: [PATCH 08/31] fix: type errors Signed-off-by: -LAN- --- api/core/workflow/workflow_type_encoder.py | 4 +++- api/pyrightconfig.json | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 6eac2dd6b4..6b2657b4dc 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -11,7 +11,9 @@ from core.variables import Segment class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: result = self._to_json_encodable_recursive(value) - return result if isinstance(result, Mapping) or result is None else dict(result) + if isinstance(result, Mapping) or result is None: + return result + return {} def _to_json_encodable_recursive(self, value: Any): if value is None: diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 7c59c2ca28..61ed3ac3b4 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -12,7 +12,7 @@ "core/ops", "core/tools", "core/model_runtime", - "core/workflow", + "core/workflow/nodes", "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", From 80f39963f1762032af67ba31887f2c4dd6c19d0e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 02:32:24 +0800 Subject: [PATCH 09/31] chore: add import lint to CI Signed-off-by: -LAN- --- .github/workflows/style.yml | 4 ++++ api/.importlinter | 15 ++++----------- dev/reformat | 3 +++ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c01f408628..302cd36229 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -43,6 +43,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv sync --project api --dev + - name: Run Import Linter + if: 
steps.changed-files.outputs.any_changed == 'true' + run: uv run --directory api --dev lint-imports + - name: Run Basedpyright Checks if: steps.changed-files.outputs.any_changed == 'true' run: dev/basedpyright-check diff --git a/api/.importlinter b/api/.importlinter index 4380e8c18e..9a593c288a 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -22,14 +22,15 @@ containers = ignore_imports = core.workflow.nodes.base.node -> core.workflow.graph_events core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events + core.workflow.nodes.loop.loop_node -> core.workflow.graph_events + + core.workflow.nodes.node_factory -> core.workflow.graph core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine core.workflow.nodes.iteration.iteration_node -> core.workflow.graph core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels - core.workflow.nodes.loop.loop_node -> core.workflow.graph_events core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels - core.workflow.nodes.node_factory -> core.workflow.graph [importlinter:contract:rsc] name = RSC @@ -59,7 +60,7 @@ layers = event_management error_handling graph_traversal - state_management + graph_state_manager worker_management domain containers = @@ -86,14 +87,6 @@ forbidden_modules = core.workflow.graph_engine.command_processing core.workflow.graph_engine.event_management -[importlinter:contract:error-handling-strategies] -name = Error Handling Strategies -type = independence -modules = - core.workflow.graph_engine.error_handling.abort_strategy - core.workflow.graph_engine.error_handling.retry_strategy - core.workflow.graph_engine.error_handling.fail_branch_strategy - core.workflow.graph_engine.error_handling.default_value_strategy [importlinter:contract:graph-traversal-components] name = Graph Traversal Components diff --git a/dev/reformat b/dev/reformat index 258b47b3bf..6966267193 100755 --- a/dev/reformat +++ b/dev/reformat @@ -5,6 +5,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/.." 
+# Import linter +uv run --directory api --dev lint-imports + # run ruff linter uv run --directory api --dev ruff check --fix ./ From 836ed1f380ed67e51705d207324f0b6d3a589c26 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 02:35:05 +0800 Subject: [PATCH 10/31] refactor(graph_engine): Move ErrorHandler into a single file package Signed-off-by: -LAN- --- api/.importlinter | 2 +- .../{error_handling => }/error_handler.py | 2 +- .../workflow/graph_engine/error_handling/__init__.py | 12 ------------ .../graph_engine/event_management/event_handlers.py | 2 +- api/core/workflow/graph_engine/graph_engine.py | 2 +- 5 files changed, 4 insertions(+), 16 deletions(-) rename api/core/workflow/graph_engine/{error_handling => }/error_handler.py (99%) delete mode 100644 api/core/workflow/graph_engine/error_handling/__init__.py diff --git a/api/.importlinter b/api/.importlinter index 9a593c288a..98fe5f50bb 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -58,7 +58,7 @@ layers = orchestration command_processing event_management - error_handling + error_handler graph_traversal graph_state_manager worker_management diff --git a/api/core/workflow/graph_engine/error_handling/error_handler.py b/api/core/workflow/graph_engine/error_handler.py similarity index 99% rename from api/core/workflow/graph_engine/error_handling/error_handler.py rename to api/core/workflow/graph_engine/error_handler.py index aa6e6287e0..62e144c12a 100644 --- a/api/core/workflow/graph_engine/error_handling/error_handler.py +++ b/api/core/workflow/graph_engine/error_handler.py @@ -23,7 +23,7 @@ from core.workflow.graph_events import ( from core.workflow.node_events import NodeRunResult if TYPE_CHECKING: - from ..domain import GraphExecution + from .domain import GraphExecution logger = logging.getLogger(__name__) diff --git a/api/core/workflow/graph_engine/error_handling/__init__.py b/api/core/workflow/graph_engine/error_handling/__init__.py deleted file mode 100644 index 3189fea0c9..0000000000 --- a/api/core/workflow/graph_engine/error_handling/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Error handling for graph engine. - -This package provides error handling functionality for managing -node execution failures with different recovery strategies. 
-""" - -from .error_handler import ErrorHandler - -__all__ = [ - "ErrorHandler", -] diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 10e7d421af..63929381de 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -31,7 +31,7 @@ from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator if TYPE_CHECKING: - from ..error_handling import ErrorHandler + from ..error_handler import ErrorHandler from ..graph_state_manager import GraphStateManager from ..graph_traversal import EdgeProcessor from .event_manager import EventManager diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 019a1aaade..ff56605d3d 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -31,7 +31,7 @@ from models.enums import UserFrom from .command_processing import AbortCommandHandler, CommandProcessor from .domain import ExecutionContext, GraphExecution from .entities.commands import AbortCommand -from .error_handling import ErrorHandler +from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager from .graph_traversal import EdgeProcessor, SkipPropagator From 9796cede72baa3e1b64a2cd7e15504e14554543a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 02:54:01 +0800 Subject: [PATCH 11/31] fix: add missing type field to node configurations in integration tests - Added 'type' field to all node data configurations in test files - Fixed test_code.py: added 'type: code' to all code node configs - Fixed test_http.py: added 'type: http-request' to all HTTP node configs - Fixed test_template_transform.py: added 'type: template-transform' to template node config - Fixed test_tool.py: added 'type: tool' to all tool node configs - Added setup_code_executor_mock fixture to test_execute_code_scientific_notation These changes fix the ValueError: 'Node X missing or invalid type information' errors that were occurring due to changes in the node factory validation requirements. 
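
For context, a minimal sketch of the config shape the stricter node factory validation now expects (the node id, output names, and values below are illustrative, not copied from any one test):

    code_config = {
        "id": "code",
        "data": {
            # The "type" key is what the factory validation requires; leaving it out
            # produced the "Node X missing or invalid type information" ValueError.
            "type": "code",
            "outputs": {"result": {"type": "number", "children": None}},
            # ... remaining node-specific fields unchanged ...
        },
    }
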
--- .../integration_tests/workflow/nodes/test_code.py | 8 +++++++- .../integration_tests/workflow/nodes/test_http.py | 14 ++++++++++++++ .../workflow/nodes/test_template_transform.py | 1 + .../integration_tests/workflow/nodes/test_tool.py | 2 ++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index a8f3253b35..e2f3a74bf9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -89,6 +89,7 @@ def test_execute_code(setup_code_executor_mock): code_config = { "id": "code", "data": { + "type": "code", "outputs": { "result": { "type": "number", @@ -135,6 +136,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): code_config = { "id": "code", "data": { + "type": "code", "outputs": { "result": { "type": "string", @@ -180,6 +182,7 @@ def test_execute_code_output_validator_depth(): code_config = { "id": "code", "data": { + "type": "code", "outputs": { "string_validator": { "type": "string", @@ -298,6 +301,7 @@ def test_execute_code_output_object_list(): code_config = { "id": "code", "data": { + "type": "code", "outputs": { "object_list": { "type": "array[object]", @@ -358,7 +362,8 @@ def test_execute_code_output_object_list(): node._transform_result(result, node._node_data.outputs) -def test_execute_code_scientific_notation(): +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code_scientific_notation(setup_code_executor_mock): code = """ def main(): return { @@ -370,6 +375,7 @@ def test_execute_code_scientific_notation(): code_config = { "id": "code", "data": { + "type": "code", "outputs": { "result": { "type": "number", diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 5e900342ce..ea99beacaa 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -77,6 +77,7 @@ def test_get(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -110,6 +111,7 @@ def test_no_auth(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -139,6 +141,7 @@ def test_custom_authorization_header(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -231,6 +234,7 @@ def test_bearer_authorization_with_custom_header_ignored(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -271,6 +275,7 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -310,6 +315,7 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -343,6 +349,7 @@ def test_template(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -378,6 +385,7 @@ def test_json(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "post", @@ -420,6 +428,7 @@ def test_x_www_form_urlencoded(setup_http_mock): config={ 
"id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "post", @@ -467,6 +476,7 @@ def test_form_data(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "post", @@ -517,6 +527,7 @@ def test_none_data(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "post", @@ -550,6 +561,7 @@ def test_mock_404(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -579,6 +591,7 @@ def test_multi_colons_parse(setup_http_mock): config={ "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", @@ -635,6 +648,7 @@ def test_nested_object_variable_selector(setup_http_mock): { "id": "1", "data": { + "type": "http-request", "title": "http", "desc": "", "method": "get", diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 02a8460ce6..53252c7f2e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -20,6 +20,7 @@ def test_execute_code(setup_code_executor_mock): config = { "id": "1", "data": { + "type": "template-transform", "title": "123", "variables": [ { diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 780fe0bee6..16d44d1eaf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -70,6 +70,7 @@ def test_tool_variable_invoke(): config={ "id": "1", "data": { + "type": "tool", "title": "a", "desc": "a", "provider_id": "time", @@ -101,6 +102,7 @@ def test_tool_mixed_invoke(): config={ "id": "1", "data": { + "type": "tool", "title": "a", "desc": "a", "provider_id": "time", From 7e69403ddaf258b5703160dff1940ccf476f81a3 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 03:12:33 +0800 Subject: [PATCH 12/31] refactor(graph_engine): use singledispatchmethod in event_handler Signed-off-by: -LAN- --- .../event_management/event_handlers.py | 73 ++++++++----------- .../graph_engine/orchestration/dispatcher.py | 2 +- 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 63929381de..244f4a4d86 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -3,6 +3,7 @@ Event handler implementations for different event types. """ import logging +from functools import singledispatchmethod from typing import TYPE_CHECKING, final from core.workflow.entities import GraphRuntimeState @@ -81,7 +82,7 @@ class EventHandler: self._state_manager = state_manager self._error_handler = error_handler - def handle_event(self, event: GraphNodeEventBase) -> None: + def dispatch(self, event: GraphNodeEventBase) -> None: """ Handle any node event by dispatching to the appropriate handler. 
@@ -92,42 +93,27 @@ class EventHandler: if event.in_loop_id or event.in_iteration_id: self._event_collector.collect(event) return + return self._dispatch(event) - # Handle specific event types - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - elif isinstance(event, NodeRunStreamChunkEvent): - self._handle_stream_chunk(event) - elif isinstance(event, NodeRunSucceededEvent): - self._handle_node_succeeded(event) - elif isinstance(event, NodeRunFailedEvent): - self._handle_node_failed(event) - elif isinstance(event, NodeRunExceptionEvent): - self._handle_node_exception(event) - elif isinstance(event, NodeRunRetryEvent): - self._handle_node_retry(event) - elif isinstance( - event, - ( - NodeRunIterationStartedEvent, - NodeRunIterationNextEvent, - NodeRunIterationSucceededEvent, - NodeRunIterationFailedEvent, - NodeRunLoopStartedEvent, - NodeRunLoopNextEvent, - NodeRunLoopSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunAgentLogEvent, - ), - ): - # Iteration and loop events are collected directly - self._event_collector.collect(event) - else: - # Collect unhandled events - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) + @singledispatchmethod + def _dispatch(self, event: GraphNodeEventBase) -> None: + self._event_collector.collect(event) + logger.warning("Unhandled event type: %s", type(event).__name__) - def _handle_node_started(self, event: NodeRunStartedEvent) -> None: + @_dispatch.register(NodeRunIterationStartedEvent) + @_dispatch.register(NodeRunIterationNextEvent) + @_dispatch.register(NodeRunIterationSucceededEvent) + @_dispatch.register(NodeRunIterationFailedEvent) + @_dispatch.register(NodeRunLoopStartedEvent) + @_dispatch.register(NodeRunLoopNextEvent) + @_dispatch.register(NodeRunLoopSucceededEvent) + @_dispatch.register(NodeRunLoopFailedEvent) + @_dispatch.register(NodeRunAgentLogEvent) + def _(self, event: GraphNodeEventBase) -> None: + self._event_collector.collect(event) + + @_dispatch.register + def _(self, event: NodeRunStartedEvent) -> None: """ Handle node started event. @@ -144,7 +130,8 @@ class EventHandler: # Collect the event self._event_collector.collect(event) - def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunStreamChunkEvent) -> None: """ Handle stream chunk event with full processing. @@ -158,7 +145,8 @@ class EventHandler: for stream_event in streaming_events: self._event_collector.collect(stream_event) - def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunSucceededEvent) -> None: """ Handle node success by coordinating subsystems. @@ -208,7 +196,8 @@ class EventHandler: # Collect the event self._event_collector.collect(event) - def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunFailedEvent) -> None: """ Handle node failure using error handler. @@ -223,14 +212,15 @@ class EventHandler: if result: # Process the resulting event (retry, exception, etc.) - self.handle_event(result) + self.dispatch(result) else: # Abort execution self._graph_execution.fail(RuntimeError(event.error)) self._event_collector.collect(event) self._state_manager.finish_execution(event.node_id) - def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunExceptionEvent) -> None: """ Handle node exception event (fail-branch strategy). 
@@ -241,7 +231,8 @@ class EventHandler: node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) node_execution.mark_taken() - def _handle_node_retry(self, event: NodeRunRetryEvent) -> None: + @_dispatch.register + def _(self, event: NodeRunRetryEvent) -> None: """ Handle node retry event. diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index bb4720a684..a7229ce4e8 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -86,7 +86,7 @@ class Dispatcher: try: event = self._event_queue.get(timeout=0.1) # Route to the event handler - self._event_handler.handle_event(event) + self._event_handler.dispatch(event) self._event_queue.task_done() except queue.Empty: # Check if execution is complete From f56fccee9d21e8e0ec2f8b9cd8d2974b9150463f Mon Sep 17 00:00:00 2001 From: quicksand Date: Wed, 10 Sep 2025 13:47:47 +0800 Subject: [PATCH 13/31] fix: workflow knowledge query raise error (#25465) --- api/core/app/apps/workflow/app_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 64ed9369d2..b009dc7715 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -105,6 +105,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): graph_runtime_state=graph_runtime_state, workflow_id=self._workflow.id, tenant_id=self._workflow.tenant_id, + user_id=self.application_generate_entity.user_id, ) # RUN WORKFLOW From 00a1af850630cce25a74d9bb81fec557cfcf9816 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 20:59:34 +0800 Subject: [PATCH 14/31] refactor(graph_engine): use singledispatch in Node Signed-off-by: -LAN- --- api/core/workflow/nodes/agent/agent_node.py | 12 ++- api/core/workflow/nodes/base/node.py | 90 +++++++++++---------- api/core/workflow/nodes/tool/tool_node.py | 4 +- api/core/workflow/workflow_cycle_manager.py | 4 +- api/services/workflow_service.py | 8 +- 5 files changed, 64 insertions(+), 54 deletions(-) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 67f16743c3..3c7dcb8d66 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -33,7 +33,13 @@ from core.workflow.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node @@ -93,7 +99,7 @@ class AgentNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: + def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError try: @@ -482,7 +488,7 @@ class AgentNode(Node): node_type: NodeType, node_id: str, node_execution_id: str, - ) -> Generator: + ) -> Generator[NodeEventBase, None, None]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ diff --git a/api/core/workflow/nodes/base/node.py 
b/api/core/workflow/nodes/base/node.py index 8816e22a85..e5db872e3b 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,8 @@ import logging from abc import abstractmethod -from collections.abc import Callable, Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from collections.abc import Generator, Mapping, Sequence +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any, ClassVar from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom @@ -88,14 +89,14 @@ class Node: def init_node_data(self, data: Mapping[str, Any]) -> None: ... @abstractmethod - def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]": + def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ Run node :return: """ raise NotImplementedError - def run(self) -> "Generator[GraphNodeEventBase, None, None]": + def run(self) -> Generator[GraphNodeEventBase, None, None]: # Generate a single node execution ID to use for all events if not self._node_execution_id: self._node_execution_id = str(uuid4()) @@ -142,8 +143,9 @@ class Node: # Handle event stream for event in result: - if isinstance(event, NodeEventBase): - event = self._convert_node_event_to_graph_node_event(event) + # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase + if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] + event = self._dispatch(event) if not event.in_iteration_id and not event.in_loop_id: event.id = self._node_execution_id @@ -240,7 +242,7 @@ class Node: return False @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return {} @classmethod @@ -261,7 +263,7 @@ class Node: # to BaseNodeData properties in a type-safe way @abstractmethod - def _get_error_strategy(self) -> Optional["ErrorStrategy"]: + def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" ... @@ -276,7 +278,7 @@ class Node: ... @abstractmethod - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: """Get the node description.""" ... 
@@ -292,7 +294,7 @@ class Node: # Public interface properties that delegate to abstract methods @property - def error_strategy(self) -> Optional["ErrorStrategy"]: + def error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" return self._get_error_strategy() @@ -307,7 +309,7 @@ class Node: return self._get_title() @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """Get the node description.""" return self._get_description() @@ -335,29 +337,15 @@ class Node: start_at=self._start_at, node_run_result=result, ) - raise Exception(f"result status {result.status} not supported") + case _: + raise Exception(f"result status {result.status} not supported") - def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase: - handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = { - StreamChunkEvent: self._handle_stream_chunk_event, - StreamCompletedEvent: self._handle_stream_completed_event, - AgentLogEvent: self._handle_agent_log_event, - LoopStartedEvent: self._handle_loop_started_event, - LoopNextEvent: self._handle_loop_next_event, - LoopSucceededEvent: self._handle_loop_succeeded_event, - LoopFailedEvent: self._handle_loop_failed_event, - IterationStartedEvent: self._handle_iteration_started_event, - IterationNextEvent: self._handle_iteration_next_event, - IterationSucceededEvent: self._handle_iteration_succeeded_event, - IterationFailedEvent: self._handle_iteration_failed_event, - RunRetrieverResourceEvent: self._handle_run_retriever_resource_event, - } - handler = handler_maps.get(type(event)) - if not handler: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - return handler(event) + @singledispatchmethod + def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: + raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: + @_dispatch.register + def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: return NodeRunStreamChunkEvent( id=self._node_execution_id, node_id=self._node_id, @@ -367,7 +355,8 @@ class Node: is_final=event.is_final, ) - def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: + @_dispatch.register + def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -386,9 +375,13 @@ class Node: node_run_result=event.node_run_result, error=event.node_run_result.error, ) - raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}") + case _: + raise NotImplementedError( + f"Node {self._node_id} does not support status {event.node_run_result.status}" + ) - def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: + @_dispatch.register + def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: return NodeRunAgentLogEvent( id=self._node_execution_id, node_id=self._node_id, @@ -403,7 +396,8 @@ class Node: metadata=event.metadata, ) - def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: + @_dispatch.register + def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: return NodeRunLoopStartedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -415,7 
+409,8 @@ class Node: predecessor_node_id=event.predecessor_node_id, ) - def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: + @_dispatch.register + def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent: return NodeRunLoopNextEvent( id=self._node_execution_id, node_id=self._node_id, @@ -425,7 +420,8 @@ class Node: pre_loop_output=event.pre_loop_output, ) - def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: + @_dispatch.register + def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: return NodeRunLoopSucceededEvent( id=self._node_execution_id, node_id=self._node_id, @@ -438,7 +434,8 @@ class Node: steps=event.steps, ) - def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: + @_dispatch.register + def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: return NodeRunLoopFailedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -452,7 +449,8 @@ class Node: error=event.error, ) - def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: + @_dispatch.register + def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: return NodeRunIterationStartedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -464,7 +462,8 @@ class Node: predecessor_node_id=event.predecessor_node_id, ) - def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: + @_dispatch.register + def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: return NodeRunIterationNextEvent( id=self._node_execution_id, node_id=self._node_id, @@ -474,7 +473,8 @@ class Node: pre_iteration_output=event.pre_iteration_output, ) - def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: + @_dispatch.register + def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: return NodeRunIterationSucceededEvent( id=self._node_execution_id, node_id=self._node_id, @@ -487,7 +487,8 @@ class Node: steps=event.steps, ) - def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: + @_dispatch.register + def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: return NodeRunIterationFailedEvent( id=self._node_execution_id, node_id=self._node_id, @@ -501,7 +502,8 @@ class Node: error=event.error, ) - def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: + @_dispatch.register + def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: return NodeRunRetrieverResourceEvent( id=self._node_execution_id, node_id=self._node_id, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c2c9def30c..6829d649d3 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -19,7 +19,7 @@ from core.workflow.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser @@ -55,7 +55,7 @@ class 
ToolNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node """ diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 5d6362b1c4..cbf4cfd136 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -356,8 +356,8 @@ class WorkflowCycleManager: workflow_execution: WorkflowExecution, event: QueueNodeStartedEvent, status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, - created_at: Optional[datetime] = None, + error: str | None = None, + created_at: datetime | None = None, ) -> WorkflowNodeExecution: """Create a node execution from an event.""" now = naive_utc_now() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 593d577f0e..05cd9610ef 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -557,7 +557,9 @@ class WorkflowService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: + def get_default_block_config( + self, node_type: str, filters: Mapping[str, object] | None = None + ) -> Mapping[str, object]: """ Get default config of node. :param node_type: node type @@ -568,12 +570,12 @@ class WorkflowService: # return default block config if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: - return None + return {} node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] default_config = node_class.get_default_config(filters=filters) if not default_config: - return None + return {} return default_config From b4c1766932e3f25de5fcbb54a5764b45918b2e90 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 21:48:05 +0800 Subject: [PATCH 15/31] fix: type errors Signed-off-by: -LAN- --- .../code_executor/code_node_provider.py | 35 ++++++++++++++-- api/core/workflow/entities/agent.py | 4 +- api/core/workflow/entities/run_condition.py | 6 +-- .../workflow/entities/workflow_execution.py | 6 +-- .../entities/workflow_node_execution.py | 28 ++++++------- api/core/workflow/graph_events/agent.py | 4 +- api/core/workflow/graph_events/graph.py | 10 ++--- api/core/workflow/graph_events/iteration.py | 22 +++++----- api/core/workflow/graph_events/loop.py | 22 +++++----- api/core/workflow/graph_events/node.py | 7 ++-- api/core/workflow/node_events/agent.py | 4 +- api/core/workflow/node_events/iteration.py | 22 +++++----- api/core/workflow/node_events/loop.py | 22 +++++----- api/core/workflow/nodes/agent/agent_node.py | 8 ++-- api/core/workflow/nodes/agent/exc.py | 27 ++++++------- api/core/workflow/nodes/answer/answer_node.py | 6 +-- api/core/workflow/nodes/base/entities.py | 12 +++--- api/core/workflow/nodes/base/node.py | 7 ++-- api/core/workflow/nodes/code/code_node.py | 12 +++--- api/core/workflow/nodes/code/entities.py | 6 +-- .../workflow/nodes/document_extractor/node.py | 6 +-- api/core/workflow/nodes/end/end_node.py | 6 +-- .../workflow/nodes/http_request/entities.py | 12 +++--- api/core/workflow/nodes/http_request/node.py | 8 ++-- api/core/workflow/nodes/if_else/entities.py | 8 ++-- .../workflow/nodes/if_else/if_else_node.py | 6 +-- api/core/workflow/nodes/iteration/entities.py | 10 ++--- .../nodes/iteration/iteration_node.py | 2 +- .../nodes/iteration/iteration_start_node.py | 6 +-- .../nodes/knowledge_retrieval/entities.py | 22 +++++----- .../knowledge_retrieval_node.py | 10 ++--- 
api/core/workflow/nodes/list_operator/node.py | 6 +-- api/core/workflow/nodes/llm/entities.py | 10 ++--- api/core/workflow/nodes/llm/llm_utils.py | 6 +-- api/core/workflow/nodes/llm/node.py | 20 +++++----- api/core/workflow/nodes/loop/entities.py | 10 ++--- api/core/workflow/nodes/loop/loop_end_node.py | 6 +-- api/core/workflow/nodes/loop/loop_node.py | 6 +-- .../workflow/nodes/loop/loop_start_node.py | 6 +-- .../nodes/parameter_extractor/entities.py | 8 ++-- .../parameter_extractor_node.py | 40 +++++++++---------- .../nodes/question_classifier/entities.py | 6 +-- .../question_classifier_node.py | 12 +++--- api/core/workflow/nodes/start/start_node.py | 6 +-- .../template_transform_node.py | 8 ++-- api/core/workflow/nodes/tool/tool_node.py | 6 +-- .../nodes/variable_aggregator/entities.py | 4 +- .../variable_aggregator_node.py | 6 +-- .../nodes/variable_assigner/v1/node.py | 6 +-- .../nodes/variable_assigner/v2/entities.py | 2 +- .../nodes/variable_assigner/v2/node.py | 6 +-- .../workflow_node_execution_repository.py | 6 +-- api/core/workflow/workflow_cycle_manager.py | 32 +++++++-------- api/core/workflow/workflow_entry.py | 6 +-- api/services/workflow_service.py | 4 +- 55 files changed, 305 insertions(+), 289 deletions(-) diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 701208080c..e93e1e4414 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -1,9 +1,33 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TypedDict from pydantic import BaseModel -class CodeNodeProvider(BaseModel): +class VariableConfig(TypedDict): + variable: str + value_selector: Sequence[str | int] + + +class OutputConfig(TypedDict): + type: str + children: None + + +class CodeConfig(TypedDict): + variables: Sequence[VariableConfig] + code_language: str + code: str + outputs: Mapping[str, OutputConfig] + + +class DefaultConfig(TypedDict): + type: str + config: CodeConfig + + +class CodeNodeProvider(BaseModel, ABC): @staticmethod @abstractmethod def get_language() -> str: @@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel): pass @classmethod - def get_default_config(cls): + def get_default_config(cls) -> DefaultConfig: return { "type": "code", "config": { - "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], + "variables": [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ], "code_language": cls.get_language(), "code": cls.get_default_code(), "outputs": {"result": {"type": "string", "children": None}}, diff --git a/api/core/workflow/entities/agent.py b/api/core/workflow/entities/agent.py index e1d9f13e31..2b4d6db76f 100644 --- a/api/core/workflow/entities/agent.py +++ b/api/core/workflow/entities/agent.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel @@ -7,4 +5,4 @@ class AgentNodeStrategyInit(BaseModel): """Agent node strategy initialization data.""" name: str - icon: Optional[str] = None + icon: str | None = None diff --git a/api/core/workflow/entities/run_condition.py b/api/core/workflow/entities/run_condition.py index eedce8842b..7b9a379215 100644 --- a/api/core/workflow/entities/run_condition.py +++ b/api/core/workflow/entities/run_condition.py @@ -1,5 +1,5 @@ import hashlib -from typing import Literal, Optional +from typing import Literal from pydantic 
import BaseModel @@ -10,10 +10,10 @@ class RunCondition(BaseModel): type: Literal["branch_identify", "condition"] """condition type""" - branch_identify: Optional[str] = None + branch_identify: str | None = None """branch identify like: sourceHandle, required when type is branch_identify""" - conditions: Optional[list[Condition]] = None + conditions: list[Condition] | None = None """conditions to run the node, required when type is condition""" @property diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index c41a17e165..a8a86d3db2 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -7,7 +7,7 @@ implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -28,7 +28,7 @@ class WorkflowExecution(BaseModel): graph: Mapping[str, Any] = Field(...) inputs: Mapping[str, Any] = Field(...) - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING error_message: str = Field(default="") @@ -37,7 +37,7 @@ class WorkflowExecution(BaseModel): exceptions_count: int = Field(default=0) started_at: datetime = Field(...) - finished_at: Optional[datetime] = None + finished_at: datetime | None = None @property def elapsed_time(self) -> float: diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index b56766232b..15f5161b82 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -39,41 +39,41 @@ class WorkflowNodeExecution(BaseModel): # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. # In most scenarios, `id` should be used as the primary identifier. 
- node_execution_id: Optional[str] = None + node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow index: int # Sequence number for ordering in trace visualization - predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed node_type: NodeType # Type of node (e.g., start, llm, knowledge) title: str # Display title of the node # Execution data - inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node - process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data - outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + inputs: Mapping[str, Any] | None = None # Input variables used by this node + process_data: Mapping[str, Any] | None = None # Intermediate processing data + outputs: Mapping[str, Any] | None = None # Output variables produced by this node # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: Optional[str] = None # Error message if execution failed + error: str | None = None # Error message if execution failed elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds # Additional metadata - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) # Timing information created_at: datetime # When execution started - finished_at: Optional[datetime] = None # When execution completed + finished_at: datetime | None = None # When execution completed def update_from_mapping( self, - inputs: Optional[Mapping[str, Any]] = None, - process_data: Optional[Mapping[str, Any]] = None, - outputs: Optional[Mapping[str, Any]] = None, - metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, + inputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, ): """ Update the model from mappings. 
diff --git a/api/core/workflow/graph_events/agent.py b/api/core/workflow/graph_events/agent.py index 971a2b918e..67d94d25eb 100644 --- a/api/core/workflow/graph_events/agent.py +++ b/api/core/workflow/graph_events/agent.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -14,4 +14,4 @@ class NodeRunAgentLogEvent(GraphAgentNodeEventBase): error: str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata") + metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 26ae5db336..4f7e886519 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -10,7 +10,7 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class GraphRunFailedEvent(BaseGraphEvent): @@ -20,11 +20,11 @@ class GraphRunFailedEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent): exceptions_count: int = Field(..., description="exception count") - outputs: Optional[dict[str, Any]] = None + outputs: dict[str, Any] | None = None class GraphRunAbortedEvent(BaseGraphEvent): """Event emitted when a graph run is aborted by user command.""" - reason: Optional[str] = Field(default=None, description="reason for abort") - outputs: Optional[dict[str, Any]] = Field(default=None, description="partial outputs if any") + reason: str | None = Field(default=None, description="reason for abort") + outputs: dict[str, Any] | None = Field(default=None, description="partial outputs if any") diff --git a/api/core/workflow/graph_events/iteration.py b/api/core/workflow/graph_events/iteration.py index 908a531d91..3d507dbe46 100644 --- a/api/core/workflow/graph_events/iteration.py +++ b/api/core/workflow/graph_events/iteration.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -10,31 +10,31 @@ from .base import GraphNodeEventBase class NodeRunIterationStartedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class NodeRunIterationNextEvent(GraphNodeEventBase): node_title: str index: int = Field(..., description="index") - pre_iteration_output: Optional[Any] = None + pre_iteration_output: Any = None class NodeRunIterationSucceededEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 class NodeRunIterationFailedEvent(GraphNodeEventBase): 
node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_events/loop.py b/api/core/workflow/graph_events/loop.py index 9982d876ba..c0b540949b 100644 --- a/api/core/workflow/graph_events/loop.py +++ b/api/core/workflow/graph_events/loop.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -10,31 +10,31 @@ from .base import GraphNodeEventBase class NodeRunLoopStartedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class NodeRunLoopNextEvent(GraphNodeEventBase): node_title: str index: int = Field(..., description="index") - pre_loop_output: Optional[Any] = None + pre_loop_output: Any = None class NodeRunLoopSucceededEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 class NodeRunLoopFailedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 1f6656535e..c6365d39c1 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from datetime import datetime -from typing import Optional from pydantic import Field @@ -12,9 +11,9 @@ from .base import GraphNodeEventBase class NodeRunStartedEvent(GraphNodeEventBase): node_title: str - predecessor_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + predecessor_node_id: str | None = None + parallel_mode_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None start_at: datetime = Field(..., description="node start time") # FIXME(-LAN-): only for ToolNode diff --git a/api/core/workflow/node_events/agent.py b/api/core/workflow/node_events/agent.py index b89e4fe54e..e5fc46ddea 100644 --- a/api/core/workflow/node_events/agent.py +++ b/api/core/workflow/node_events/agent.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -14,5 +14,5 @@ class AgentLogEvent(NodeEventBase): error: 
str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata") + metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") node_id: str = Field(..., description="node id") diff --git a/api/core/workflow/node_events/iteration.py b/api/core/workflow/node_events/iteration.py index 36c74ac9f1..db0b41a43a 100644 --- a/api/core/workflow/node_events/iteration.py +++ b/api/core/workflow/node_events/iteration.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -9,28 +9,28 @@ from .base import NodeEventBase class IterationStartedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class IterationNextEvent(NodeEventBase): index: int = Field(..., description="index") - pre_iteration_output: Optional[Any] = None + pre_iteration_output: Any = None class IterationSucceededEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 class IterationFailedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/node_events/loop.py b/api/core/workflow/node_events/loop.py index 5115fa9d3d..4e84fb0061 100644 --- a/api/core/workflow/node_events/loop.py +++ b/api/core/workflow/node_events/loop.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -9,28 +9,28 @@ from .base import NodeEventBase class LoopStartedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None + inputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None + predecessor_node_id: str | None = None class LoopNextEvent(NodeEventBase): index: int = Field(..., description="index") - pre_loop_output: Optional[Any] = None + pre_loop_output: Any = None class LoopSucceededEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 class LoopFailedEvent(NodeEventBase): start_at: datetime = 
Field(..., description="start at") - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None + metadata: Mapping[str, Any] | None = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 3c7dcb8d66..cbd5af5013 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from packaging.version import Version from pydantic import ValidationError @@ -77,7 +77,7 @@ class AgentNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AgentNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -86,7 +86,7 @@ class AgentNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -414,7 +414,7 @@ class AgentNode(Node): icon = None return icon - def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]: + def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: # get conversation id conversation_id_variable = self.graph_runtime_state.variable_pool.get( ["sys", SystemVariableKey.CONVERSATION_ID.value] diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py index d5955bdd7d..944f5f0b20 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exc.py @@ -1,6 +1,3 @@ -from typing import Optional - - class AgentNodeError(Exception): """Base exception for all agent node errors.""" @@ -12,7 +9,7 @@ class AgentNodeError(Exception): class AgentStrategyError(AgentNodeError): """Exception raised when there's an error with the agent strategy.""" - def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None): self.strategy_name = strategy_name self.provider_name = provider_name super().__init__(message) @@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError): class AgentStrategyNotFoundError(AgentStrategyError): """Exception raised when the specified agent strategy is not found.""" - def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + def __init__(self, strategy_name: str, provider_name: str | None = None): super().__init__( f"Agent strategy '{strategy_name}' not found" + (f" for provider '{provider_name}'" if provider_name else ""), @@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError): class AgentInvocationError(AgentNodeError): """Exception raised when there's an error invoking the agent.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ 
-41,7 +38,7 @@ class AgentInvocationError(AgentNodeError): class AgentParameterError(AgentNodeError): """Exception raised when there's an error with agent parameters.""" - def __init__(self, message: str, parameter_name: Optional[str] = None): + def __init__(self, message: str, parameter_name: str | None = None): self.parameter_name = parameter_name super().__init__(message) @@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError): class AgentVariableError(AgentNodeError): """Exception raised when there's an error with variables in the agent node.""" - def __init__(self, message: str, variable_name: Optional[str] = None): + def __init__(self, message: str, variable_name: str | None = None): self.variable_name = variable_name super().__init__(message) @@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError): class ToolFileError(AgentNodeError): """Exception raised when there's an error with a tool file.""" - def __init__(self, message: str, file_id: Optional[str] = None): + def __init__(self, message: str, file_id: str | None = None): self.file_id = file_id super().__init__(message) @@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError): class AgentMessageTransformError(AgentNodeError): """Exception raised when there's an error transforming agent messages.""" - def __init__(self, message: str, original_error: Optional[Exception] = None): + def __init__(self, message: str, original_error: Exception | None = None): self.original_error = original_error super().__init__(message) @@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError): class AgentModelError(AgentNodeError): """Exception raised when there's an error with the model used by the agent.""" - def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + def __init__(self, message: str, model_name: str | None = None, provider: str | None = None): self.model_name = model_name self.provider = provider super().__init__(message) @@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError): class AgentMemoryError(AgentNodeError): """Exception raised when there's an error with the agent's memory.""" - def __init__(self, message: str, conversation_id: Optional[str] = None): + def __init__(self, message: str, conversation_id: str | None = None): self.conversation_id = conversation_id super().__init__(message) @@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError): def __init__( self, message: str, - variable_name: Optional[str] = None, - expected_type: Optional[str] = None, - actual_type: Optional[str] = None, + variable_name: str | None = None, + expected_type: str | None = None, + actual_type: str | None = None, ): self.variable_name = variable_name self.expected_type = expected_type diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 4ef5c880c4..86174c7ea6 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.variables import ArrayFileSegment, FileSegment, Segment from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus @@ -20,7 +20,7 @@ class AnswerNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = AnswerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: 
return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class AnswerNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 8c6ea0d59d..5aef9d79cf 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -2,7 +2,7 @@ import json from abc import ABC from collections.abc import Sequence from enum import StrEnum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, model_validator @@ -128,10 +128,10 @@ class DefaultValue(BaseModel): class BaseNodeData(ABC, BaseModel): title: str - desc: Optional[str] = None + desc: str | None = None version: str = "1" - error_strategy: Optional[ErrorStrategy] = None - default_value: Optional[list[DefaultValue]] = None + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None retry_config: RetryConfig = RetryConfig() @property @@ -142,7 +142,7 @@ class BaseNodeData(ABC, BaseModel): class BaseIterationNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseIterationState(BaseModel): @@ -157,7 +157,7 @@ class BaseIterationState(BaseModel): class BaseLoopNodeData(BaseNodeData): - start_node_id: Optional[str] = None + start_node_id: str | None = None class BaseLoopState(BaseModel): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index e5db872e3b..de6f4152c6 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -145,11 +145,10 @@ class Node: for event in result: # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] - event = self._dispatch(event) - - if not event.in_iteration_id and not event.in_loop_id: + yield self._dispatch(event) + elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] event.id = self._node_execution_id - yield event + yield event except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 8171686022..c87cbf9628 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import Any, Optional +from typing import Any, cast from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -30,7 +30,7 @@ class CodeNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = CodeNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -39,7 +39,7 @@ class CodeNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | 
None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -49,7 +49,7 @@ class CodeNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: """ Get default config of node. :param filters: filter by node config parameters. @@ -57,7 +57,7 @@ class CodeNode(Node): """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = filters.get("code_language", CodeLanguage.PYTHON3) + code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) @@ -154,7 +154,7 @@ class CodeNode(Node): def _transform_result( self, result: Mapping[str, Any], - output_schema: Optional[dict[str, CodeNodeData.Output]], + output_schema: dict[str, CodeNodeData.Output] | None, prefix: str = "", depth: int = 1, ): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index c8095e26e1..10a1c897e9 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal, Self from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: Optional[dict[str, "CodeNodeData.Output"]] = None + children: dict[str, Self] | None = None class Dependency(BaseModel): name: str @@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData): code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[Dependency]] = None + dependencies: list[Dependency] | None = None diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 38213ea4b4..ae1061d72c 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any import chardet import docx @@ -49,7 +49,7 @@ class DocumentExtractorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = DocumentExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -58,7 +58,7 @@ class DocumentExtractorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index ca2aeddf3e..2bdfe4efce 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, 
WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -18,7 +18,7 @@ class EndNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = EndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class EndNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 8d7ba25d47..5a7db6e0e6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,7 +1,7 @@ import mimetypes from collections.abc import Sequence from email.message import Message -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator @@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel): class HttpRequestNodeAuthorization(BaseModel): type: Literal["no-auth", "api-key"] - config: Optional[HttpRequestNodeAuthorizationConfig] = None + config: HttpRequestNodeAuthorizationConfig | None = None @field_validator("config", mode="before") @classmethod @@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData): authorization: HttpRequestNodeAuthorization headers: str params: str - body: Optional[HttpRequestNodeBody] = None - timeout: Optional[HttpRequestNodeTimeout] = None - ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + body: HttpRequestNodeBody | None = None + timeout: HttpRequestNodeTimeout | None = None + ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY class Response: @@ -183,7 +183,7 @@ class Response: return f"{(self.size / 1024 / 1024):.2f} MB" @property - def parsed_content_disposition(self) -> Optional[Message]: + def parsed_content_disposition(self) -> Message | None: content_disposition = self.headers.get("content-disposition", "") if content_disposition: msg = Message() diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 8186a002f8..826820a8e3 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from configs import dify_config from core.file import File, FileTransferMethod @@ -39,7 +39,7 @@ class HttpRequestNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = HttpRequestNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -48,7 +48,7 @@ class HttpRequestNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -58,7 +58,7 @@ class HttpRequestNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: 
Optional[dict[str, Any]] = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "http-request", "config": { diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index 67d6d6a886..b22bd6f508 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData): logical_operator: Literal["and", "or"] conditions: list[Condition] - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) - cases: Optional[list[Case]] = None + cases: list[Case] | None = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 2149a9a05b..075f6f8444 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from typing_extensions import deprecated @@ -22,7 +22,7 @@ class IfElseNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IfElseNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -31,7 +31,7 @@ class IfElseNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 7a489dd725..ed4ab2c11c 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import Field @@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ - parent_loop_id: Optional[str] = None # redundant field, not used currently + parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector is_parallel: bool = False # open the parallel mode or not @@ -39,7 +39,7 @@ class IterationState(BaseIterationState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any = None class MetaData(BaseIterationState.MetaData): """ @@ -48,7 +48,7 @@ class IterationState(BaseIterationState): iterator_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any: """ Get last output. """ @@ -56,7 +56,7 @@ class IterationState(BaseIterationState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any: """ Get current output. 
""" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e092536d0a..10fe7473bb 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -82,7 +82,7 @@ class IterationNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: dict[str, object] | None = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "iteration", "config": { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index c03e7257a2..80f39ccebc 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class IterationStartNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class IterationStartNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index b71271abeb..460290f0ea 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field @@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel): """ top_k: int - score_threshold: Optional[float] = None + score_threshold: float | None = None reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class SingleRetrievalConfig(BaseModel): @@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel): Metadata Filtering Condition. 
""" - logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[Condition] | None = Field(default=None, deprecated=True) class KnowledgeRetrievalNodeData(BaseNodeData): @@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData): query_variable_selector: list[str] dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] - multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None - single_retrieval_config: Optional[SingleRetrievalConfig] = None - metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: Optional[ModelConfig] = None - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None + multiple_retrieval_config: MultipleRetrievalConfig | None = None + single_retrieval_config: SingleRetrievalConfig | None = None + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" + metadata_model_config: ModelConfig | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d66b0cdf1a..7e4843b4c4 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import Float, and_, func, or_, select, text from sqlalchemy import cast as sqlalchemy_cast @@ -119,7 +119,7 @@ class KnowledgeRetrievalNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -410,7 +410,7 @@ class KnowledgeRetrievalNode(Node): def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: + ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: document_query = db.session.query(Document).where( Document.dataset_id.in_(dataset_ids), Document.indexing_status == "completed", @@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node): return automatic_metadata_filters def _process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any] + self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any] ) -> list[Any]: if value is None and condition not in ("empty", "not empty"): return filters diff --git a/api/core/workflow/nodes/list_operator/node.py 
b/api/core/workflow/nodes/list_operator/node.py index b604008656..7a31d69221 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import Any, Optional, TypeAlias, TypeVar +from typing import Any, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -43,7 +43,7 @@ class ListOperatorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ListOperatorNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -52,7 +52,7 @@ class ListOperatorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 72f83eb25b..fe6f2290aa 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -18,7 +18,7 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): enabled: bool - variable_selector: Optional[list[str]] = None + variable_selector: list[str] | None = None class VisionConfigOptions(BaseModel): @@ -51,18 +51,18 @@ class PromptConfig(BaseModel): class LLMNodeChatModelMessage(ChatModelMessage): text: str = "" - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: Optional[str] = None + jinja2_text: str | None = None class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: Optional[MemoryConfig] = None + memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: Mapping[str, Any] | None = None diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index af22b8588c..ad969cdad1 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Optional, cast +from typing import cast from sqlalchemy import select, update from sqlalchemy.orm import Session @@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc def fetch_memory( - variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance -) -> Optional[TokenBufferMemory]: + variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance +) -> TokenBufferMemory | None: if not node_data_memory: return None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 175581d95f..a0f4836e82 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ 
-4,7 +4,7 @@ import json import logging import re from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -139,7 +139,7 @@ class LLMNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LLMNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -148,7 +148,7 @@ class LLMNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -354,10 +354,10 @@ class LLMNode(Node): node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], - stop: Optional[Sequence[str]] = None, + stop: Sequence[str] | None = None, user_id: str, structured_output_enabled: bool, - structured_output: Optional[Mapping[str, Any]] = None, + structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, file_outputs: list["File"], node_id: str, @@ -716,7 +716,7 @@ class LLMNode(Node): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, - ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): @@ -959,7 +959,7 @@ class LLMNode(Node): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "type": "llm", "config": { @@ -987,7 +987,7 @@ class LLMNode(Node): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, @@ -1175,7 +1175,7 @@ class LLMNode(Node): def _combine_message_content_with_role( - *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole ): match role: case PromptMessageRole.USER: @@ -1281,7 +1281,7 @@ def _handle_memory_completion_mode( def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: Optional[str], + context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ) -> Sequence[PromptMessage]: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 6f6939810b..90881ba3b7 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator @@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData): loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] - 
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData]) + loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) outputs: dict[str, Any] = Field(default_factory=dict) @field_validator("outputs", mode="before") @@ -74,7 +74,7 @@ class LoopState(BaseLoopState): """ outputs: list[Any] = Field(default_factory=list) - current_output: Optional[Any] = None + current_output: Any = None class MetaData(BaseLoopState.MetaData): """ @@ -83,7 +83,7 @@ class LoopState(BaseLoopState): loop_length: int - def get_last_output(self) -> Optional[Any]: + def get_last_output(self) -> Any: """ Get last output. """ @@ -91,7 +91,7 @@ class LoopState(BaseLoopState): return self.outputs[-1] return None - def get_current_output(self) -> Optional[Any]: + def get_current_output(self) -> Any: """ Get current output. """ diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 8b1b5b424d..38aef06d24 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class LoopEndNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopEndNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class LoopEndNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index ba26322cc3..2217bc205e 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, cast from core.variables import Segment, SegmentType from core.workflow.enums import ( @@ -51,7 +51,7 @@ class LoopNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -60,7 +60,7 @@ class LoopNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 9f3febe9b0..e777a8cbe9 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from 
core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult @@ -20,7 +20,7 @@ class LoopStartNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = LoopStartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class LoopStartNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 4c0b14b2d7..4e3819c4cf 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -48,7 +48,7 @@ class ParameterConfig(BaseModel): name: str type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: Optional[list[str]] = None + options: list[str] | None = None description: str required: bool @@ -86,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData): model: ModelConfig query: list[str] parameters: list[ParameterConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None reasoning_mode: Literal["function_call", "prompt"] vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 3f79006836..875a0598e0 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -3,7 +3,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import File @@ -96,7 +96,7 @@ class ParameterExtractorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = ParameterExtractorNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -105,7 +105,7 @@ class ParameterExtractorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -114,11 +114,11 @@ class ParameterExtractorNode(Node): def get_base_node_data(self) -> BaseNodeData: return self._node_data - _model_instance: Optional[ModelInstance] = None - _model_config: Optional[ModelConfigWithCredentialsEntity] = None + _model_instance: ModelInstance | None = None + _model_config: ModelConfigWithCredentialsEntity | None = None @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, 
filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { "model": { "prompt_templates": { @@ -293,7 +293,7 @@ class ParameterExtractorNode(Node): prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], stop: list[str], - ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: + ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params, @@ -323,9 +323,9 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. @@ -405,9 +405,9 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate prompt engineering prompt. @@ -443,9 +443,9 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate completion prompt. @@ -477,9 +477,9 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, files: Sequence[File], - vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None, + vision_detail: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ Generate chat prompt. @@ -651,7 +651,7 @@ class ParameterExtractorNode(Node): return transformed_result - def _extract_complete_json_response(self, result: str) -> Optional[dict]: + def _extract_complete_json_response(self, result: str) -> dict | None: """ Extract complete json response. """ @@ -666,7 +666,7 @@ class ParameterExtractorNode(Node): logger.info("extra error: %s", result) return None - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: """ Extract json from tool call. 
""" @@ -705,7 +705,7 @@ class ParameterExtractorNode(Node): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ) -> list[ChatModelMessage]: model_mode = ModelMode(node_data.model.mode) @@ -732,7 +732,7 @@ class ParameterExtractorNode(Node): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) @@ -768,7 +768,7 @@ class ParameterExtractorNode(Node): query: str, variable_pool: VariablePool, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 6248df0edf..edde30708a 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData): query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] - instruction: Optional[str] = None - memory: Optional[MemoryConfig] = None + instruction: str | None = None + memory: MemoryConfig | None = None vision: VisionConfig = Field(default_factory=VisionConfig) @property diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 929216652e..483cfff574 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -80,7 +80,7 @@ class QuestionClassifierNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = QuestionClassifierNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -89,7 +89,7 @@ class QuestionClassifierNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -271,7 +271,7 @@ class QuestionClassifierNode(Node): return variable_mapping @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: """ Get default config of node. :param filters: filter by node config parameters (not used in this implementation). 
@@ -285,7 +285,7 @@ class QuestionClassifierNode(Node): node_data: QuestionClassifierNodeData, query: str, model_config: ModelConfigWithCredentialsEntity, - context: Optional[str], + context: str | None, ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) @@ -328,7 +328,7 @@ class QuestionClassifierNode(Node): self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], + memory: TokenBufferMemory | None, max_token_limit: int = 2000, ): model_mode = ModelMode(node_data.model.mode) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 608f6b11cc..2f33c54128 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus @@ -18,7 +18,7 @@ class StartNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = StartNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -27,7 +27,7 @@ class StartNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 9039476871..cf05ef253a 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,6 +1,6 @@ import os from collections.abc import Mapping, Sequence -from typing import Any, Optional +from typing import Any from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus @@ -20,7 +20,7 @@ class TemplateTransformNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = TemplateTransformNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -29,7 +29,7 @@ class TemplateTransformNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -39,7 +39,7 @@ class TemplateTransformNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: """ Get default config of node. :param filters: filter by node config parameters. 
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6829d649d3..ecce28b2ad 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from sqlalchemy import select from sqlalchemy.orm import Session @@ -470,7 +470,7 @@ class ToolNode(Node): return result - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -479,7 +479,7 @@ class ToolNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index f4577d7573..13dbc5dbe6 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import BaseModel from core.variables.types import SegmentType @@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData): type: str = "variable-assigner" output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None + advanced_settings: AdvancedSettings | None = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index d2627d9d3b..be00d55937 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.variables.segments import Segment from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus @@ -17,7 +17,7 @@ class VariableAggregatorNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData(**data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -26,7 +26,7 @@ class VariableAggregatorNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 5eb9938b9e..c2a9ecd7fb 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias from core.variables import SegmentType, Variable from core.variables.segments import BooleanSegment @@ -33,7 +33,7 @@ class VariableAssignerNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = 
VariableAssignerData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -42,7 +42,7 @@ class VariableAssignerNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py index bdb8716b8a..2955730289 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py @@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel): # 2. For VARIABLE input_type: Initially contains the selector of the source variable. # 3. During the variable updating procedure: The `value` field is reassigned to hold # the resolved actual value that will be applied to the target variable. - value: Any | None = None + value: Any = None class VariableAssignerNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index e7833aa46f..a89055fd66 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -60,7 +60,7 @@ class VariableAssignerNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = VariableAssignerNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -69,7 +69,7 @@ class VariableAssignerNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index aba8f9ed20..3c1c28b3f3 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Optional, Protocol +from typing import Literal, Protocol from core.workflow.entities import WorkflowNodeExecution @@ -10,7 +10,7 @@ class OrderConfig: """Configuration for ordering NodeExecution instances.""" order_by: list[str] - order_direction: Optional[Literal["asc", "desc"]] = None + order_direction: Literal["asc", "desc"] | None = None class WorkflowNodeExecutionRepository(Protocol): @@ -42,7 +42,7 @@ class WorkflowNodeExecutionRepository(Protocol): def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ Retrieve all NodeExecution instances for a specific 
workflow run. diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index cbf4cfd136..a477733bda 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, Union +from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -85,9 +85,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -112,9 +112,9 @@ class WorkflowCycleManager: total_steps: int, outputs: Mapping[str, Any] | None = None, exceptions_count: int = 0, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, - external_trace_id: Optional[str] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) @@ -140,10 +140,10 @@ class WorkflowCycleManager: total_steps: int, status: WorkflowExecutionStatus, error_message: str, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None, + conversation_id: str | None = None, + trace_manager: TraceQueueManager | None = None, exceptions_count: int = 0, - external_trace_id: Optional[str] = None, + external_trace_id: str | None = None, ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() @@ -295,9 +295,9 @@ class WorkflowCycleManager: total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, - error_message: Optional[str] = None, + error_message: str | None = None, exceptions_count: int = 0, - finished_at: Optional[datetime] = None, + finished_at: datetime | None = None, ): """Update workflow execution with completion data.""" execution.status = status @@ -311,10 +311,10 @@ class WorkflowCycleManager: def _add_trace_task_if_needed( self, - trace_manager: Optional[TraceQueueManager], + trace_manager: TraceQueueManager | None, workflow_execution: WorkflowExecution, - conversation_id: Optional[str], - external_trace_id: Optional[str], + conversation_id: str | None, + external_trace_id: str | None, ): """Add trace task if trace manager is provided.""" if trace_manager: @@ -401,7 +401,7 @@ class WorkflowCycleManager: QueueNodeExceptionEvent, ], status: WorkflowNodeExecutionStatus, - error: Optional[str] = None, + error: str | None = None, handle_special_values: bool = False, ): """Update node execution with completion data.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 466e537a1a..99f969a8be 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, 
Optional
+from typing import Any
 
 from configs import dify_config
 from core.app.apps.exc import GenerateTaskStoppedError
@@ -43,7 +43,7 @@ class WorkflowEntry:
         call_depth: int,
         variable_pool: VariablePool,
         graph_runtime_state: GraphRuntimeState,
-        command_channel: Optional[CommandChannel] = None,
+        command_channel: CommandChannel | None = None,
     ) -> None:
         """
         Init workflow entry
@@ -341,7 +341,7 @@ class WorkflowEntry:
            raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
 
     @staticmethod
-    def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
+    def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
         # NOTE(QuantumGhost): Avoid using this function in new code.
         # Keep values structured as long as possible and only convert to dict
         # immediately before serialization (e.g., JSON serialization) to maintain
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 05cd9610ef..da96526121 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -543,12 +543,12 @@ class WorkflowService:
            # This will prevent validation errors from breaking the workflow
            return []
 
-    def get_default_block_configs(self) -> list[dict]:
+    def get_default_block_configs(self) -> Sequence[Mapping[str, object]]:
         """
         Get default block configs
         """
         # return default block config
-        default_block_configs = []
+        default_block_configs: list[Mapping[str, object]] = []
         for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
             node_class = node_class_mapping[LATEST_VERSION]
             default_config = node_class.get_default_config()

From a923ab1ab8cda1613c96672731abcaeab124ba33 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 11 Sep 2025 15:01:16 +0800
Subject: [PATCH 16/31] fix: type errors

Signed-off-by: -LAN-
---
 .../remote_settings_sources/apollo/utils.py | 2 +-
 .../common/workflow_response_converter.py | 8 +-
 api/core/app/apps/workflow_app_runner.py | 2 -
 api/core/app/entities/queue_entities.py | 164 ++++++++---------
 api/core/app/entities/task_entities.py | 170 +++++++++---------
 .../base/tts/app_generator_tts_publisher.py | 6 +-
 api/core/mcp/types.py | 2 +-
 api/core/rag/entities/citation_metadata.py | 36 ++--
 api/core/tools/entities/api_entities.py | 37 ++--
 api/core/tools/entities/tool_entities.py | 54 +++---
 api/core/tools/mcp_tool/provider.py | 1 -
 api/core/tools/tool_manager.py | 21 ++-
 api/core/workflow/graph_events/agent.py | 2 +-
 api/core/workflow/graph_events/graph.py | 8 +-
 api/core/workflow/graph_events/iteration.py | 16 +-
 api/core/workflow/graph_events/loop.py | 16 +-
 api/core/workflow/graph_events/node.py | 1 -
 api/core/workflow/node_events/agent.py | 2 +-
 api/core/workflow/node_events/iteration.py | 16 +-
 api/core/workflow/node_events/loop.py | 16 +-
 api/core/workflow/nodes/base/node.py | 11 +-
 api/models/model.py | 2 +-
 api/models/tools.py | 5 +-
 api/services/tools/tools_transform_service.py | 27 +--
 .../core/tools/workflow_as_tool/test_tool.py | 1 -
 25 files changed, 310 insertions(+), 316 deletions(-)

diff --git a/api/configs/remote_settings_sources/apollo/utils.py b/api/configs/remote_settings_sources/apollo/utils.py
index cff187954d..40731448a0 100644
--- a/api/configs/remote_settings_sources/apollo/utils.py
+++ b/api/configs/remote_settings_sources/apollo/utils.py
@@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
 
 
 # Returns whether the obtained value is obtained, and None if it does not
-def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | 
None: +def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any: if namespace_cache: kv_data = namespace_cache.get(CONFIGURATIONS) if kv_data is None: diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index e4796dd3d0..6dd739429f 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -319,7 +319,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_title, - outputs=json_converter.to_json_encodable(event.outputs), + outputs=json_converter.to_json_encodable(event.outputs) or {}, created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -328,7 +328,7 @@ class WorkflowResponseConverter: else WorkflowNodeExecutionStatus.FAILED, error=None, elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), - total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, + total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)), execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, @@ -395,7 +395,7 @@ class WorkflowResponseConverter: node_id=event.node_id, node_type=event.node_type.value, title=event.node_title, - outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs), + outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs) or {}, created_at=int(time.time()), extras={}, inputs=event.inputs or {}, @@ -404,7 +404,7 @@ class WorkflowResponseConverter: else WorkflowNodeExecutionStatus.FAILED, error=None, elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), - total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, + total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)), execution_metadata=event.metadata, finished_at=int(time.time()), steps=event.steps, diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 9b104cdace..056e03fa14 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -384,7 +384,6 @@ class WorkflowBasedAppRunner: predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, - parallel_mode_run_id=event.parallel_mode_run_id, inputs=inputs, process_data=process_data, outputs=outputs, @@ -406,7 +405,6 @@ class WorkflowBasedAppRunner: predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, - parallel_mode_run_id=event.parallel_mode_run_id, agent_strategy=event.agent_strategy, provider_type=event.provider_type, provider_id=event.provider_id, diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index fc2991f1ea..34bacfbd6c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,9 +1,9 @@ from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum, StrEnum -from typing import Any, Optional +from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata @@ -79,9 +79,9 @@ class QueueIterationStartEvent(AppQueueEvent): start_at: datetime node_run_index: 
int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + predecessor_node_id: str | None = None + metadata: Mapping[str, object] = Field(default_factory=dict) class QueueIterationNextEvent(AppQueueEvent): @@ -97,7 +97,7 @@ class QueueIterationNextEvent(AppQueueEvent): node_type: NodeType node_title: str node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Any = None # output for the current iteration class QueueIterationCompletedEvent(AppQueueEvent): @@ -114,12 +114,12 @@ class QueueIterationCompletedEvent(AppQueueEvent): start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueLoopStartEvent(AppQueueEvent): @@ -132,20 +132,20 @@ class QueueLoopStartEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + predecessor_node_id: str | None = None + metadata: Mapping[str, object] = Field(default_factory=dict) class QueueLoopNextEvent(AppQueueEvent): @@ -160,18 +160,18 @@ class QueueLoopNextEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - parallel_mode_run_id: Optional[str] = None + parallel_mode_run_id: str | None = None """iteration run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current loop + output: Any = None # output for the current loop class QueueLoopCompletedEvent(AppQueueEvent): @@ -185,23 +185,23 @@ class QueueLoopCompletedEvent(AppQueueEvent): node_id: str node_type: NodeType node_title: str - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if 
node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 - error: Optional[str] = None + error: str | None = None class QueueTextChunkEvent(AppQueueEvent): @@ -211,11 +211,11 @@ class QueueTextChunkEvent(AppQueueEvent): event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None """from variable selector""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -252,9 +252,9 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: Sequence[RetrievalSourceMetadata] - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" @@ -273,7 +273,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.MESSAGE_END - llm_result: Optional[LLMResult] = None + llm_result: LLMResult | None = None class QueueAdvancedChatMessageEndEvent(AppQueueEvent): @@ -299,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED - outputs: Optional[dict[str, Any]] = None + outputs: Mapping[str, object] = Field(default_factory=dict) class QueueWorkflowFailedEvent(AppQueueEvent): @@ -319,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED exceptions_count: int - outputs: Optional[dict[str, Any]] = None + outputs: Mapping[str, object] = Field(default_factory=dict) class QueueNodeStartedEvent(AppQueueEvent): @@ -334,16 +334,16 @@ class QueueNodeStartedEvent(AppQueueEvent): node_title: str node_type: NodeType node_run_index: int = 1 # FIXME(-LAN-): may not used - predecessor_node_id: Optional[str] = None - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - in_iteration_id: Optional[str] = None - in_loop_id: Optional[str] = None + predecessor_node_id: str | None = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + in_iteration_id: str | None = None + in_loop_id: str | None = None start_at: datetime - parallel_mode_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + parallel_mode_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None # FIXME(-LAN-): only for ToolNode, need to refactor provider_type: str # should be a 
core.tools.entities.tool_entities.ToolProviderType @@ -360,26 +360,26 @@ class QueueNodeSucceededEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + process_data: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None - error: Optional[str] = None + error: str | None = None class QueueAgentLogEvent(AppQueueEvent): @@ -395,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, object] = Field(default_factory=dict) node_id: str @@ -404,10 +404,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): event: QueueEvent = QueueEvent.RETRY - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + process_data: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str retry_index: int # retry index @@ -423,24 +423,24 @@ class QueueNodeExceptionEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: Optional[str] = None + parallel_id: str | None = None """parallel id if node is in parallel""" - parallel_start_node_id: Optional[str] = None + parallel_start_node_id: str | None = None """parallel start node id if node is in parallel""" - parent_parallel_id: Optional[str] = None + parent_parallel_id: str | None = None """parent parallel id if node is in parallel""" - parent_parallel_start_node_id: Optional[str] = None + parent_parallel_start_node_id: str | None = None """parent parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: 
Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + process_data: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -455,17 +455,17 @@ class QueueNodeFailedEvent(AppQueueEvent): node_execution_id: str node_id: str node_type: NodeType - parallel_id: Optional[str] = None - in_iteration_id: Optional[str] = None + parallel_id: str | None = None + in_iteration_id: str | None = None """iteration id if node is in iteration""" - in_loop_id: Optional[str] = None + in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + inputs: Mapping[str, object] = Field(default_factory=dict) + process_data: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None error: str @@ -494,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Optional[Any] = None + error: Any = None class QueuePingEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 86f05e9624..59e4ffb351 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -108,7 +108,7 @@ class MessageStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE id: str answer: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None class MessageAudioStreamResponse(StreamResponse): @@ -136,8 +136,8 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str - metadata: dict = Field(default_factory=dict) - files: Optional[Sequence[Mapping[str, Any]]] = None + metadata: Mapping[str, object] = Field(default_factory=dict) + files: Sequence[Mapping[str, Any]] | None = None class MessageFileStreamResponse(StreamResponse): @@ -170,12 +170,12 @@ class AgentThoughtStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int - thought: Optional[str] = None - observation: Optional[str] = None - tool: Optional[str] = None - tool_labels: Optional[dict] = None - tool_input: Optional[str] = None - message_files: Optional[list[str]] = None + thought: str | None = None + observation: str | None = None + tool: str | None = None + tool_labels: Mapping[str, object] = Field(default_factory=dict) + tool_input: str | None = None + message_files: list[str] | None = None class AgentMessageStreamResponse(StreamResponse): @@ -221,16 +221,16 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int - created_by: Optional[dict] = None + created_by: 
Mapping[str, object] = Field(default_factory=dict) created_at: int finished_at: int - exceptions_count: Optional[int] = 0 - files: Optional[Sequence[Mapping[str, Any]]] = [] + exceptions_count: int | None = 0 + files: Sequence[Mapping[str, Any]] | None = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -252,18 +252,18 @@ class NodeStartStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None created_at: int - extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None - parallel_run_id: Optional[str] = None - agent_strategy: Optional[AgentNodeStrategyInit] = None + extras: dict[str, object] = Field(default_factory=dict) + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None + parallel_run_id: str | None = None + agent_strategy: AgentNodeStrategyInit | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -309,23 +309,23 @@ class NodeFinishStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -378,23 +378,23 @@ class NodeRetryStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + outputs: Mapping[str, Any] | None = None status: str - error: Optional[str] = None + error: str | None = None elapsed_time: float - execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None + execution_metadata: 
Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None created_at: int finished_at: int - files: Optional[Sequence[Mapping[str, Any]]] = [] - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parent_parallel_id: Optional[str] = None - parent_parallel_start_node_id: Optional[str] = None - iteration_id: Optional[str] = None - loop_id: Optional[str] = None + files: Sequence[Mapping[str, Any]] | None = [] + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parent_parallel_id: str | None = None + parent_parallel_start_node_id: str | None = None + iteration_id: str | None = None + loop_id: str | None = None retry_index: int = 0 event: StreamEvent = StreamEvent.NODE_RETRY @@ -449,9 +449,9 @@ class IterationNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) - metadata: Mapping = {} - inputs: Mapping = {} + extras: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) + inputs: Mapping[str, object] = Field(default_factory=dict) event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -474,7 +474,7 @@ class IterationNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - extras: dict = Field(default_factory=dict) + extras: Mapping[str, object] = Field(default_factory=dict) event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -495,15 +495,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping[str, object] = Field(default_factory=dict) created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: Mapping[str, object] = Field(default_factory=dict) + inputs: Mapping[str, object] = Field(default_factory=dict) status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping[str, object] = Field(default_factory=dict) finished_at: int steps: int @@ -527,11 +527,11 @@ class LoopNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) - metadata: Mapping = {} - inputs: Mapping = {} - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + extras: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) + inputs: Mapping[str, object] = Field(default_factory=dict) + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -554,11 +554,11 @@ class LoopNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - pre_loop_output: Optional[Any] = None - extras: dict = Field(default_factory=dict) - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None - parallel_mode_run_id: Optional[str] = None + pre_loop_output: Any = None + extras: Mapping[str, object] = Field(default_factory=dict) + parallel_id: str | None = None + parallel_start_node_id: str | None = None + parallel_mode_run_id: str | None = None event: StreamEvent = StreamEvent.LOOP_NEXT workflow_run_id: str @@ -579,19 +579,19 @@ class LoopNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: 
Optional[Mapping] = None + outputs: Mapping[str, object] = Field(default_factory=dict) created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: Mapping[str, object] = Field(default_factory=dict) + inputs: Mapping[str, object] = Field(default_factory=dict) status: WorkflowNodeExecutionStatus - error: Optional[str] = None + error: str | None = None elapsed_time: float total_tokens: int - execution_metadata: Optional[Mapping] = None + execution_metadata: Mapping[str, object] = Field(default_factory=dict) finished_at: int steps: int - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_COMPLETED workflow_run_id: str @@ -609,7 +609,7 @@ class TextChunkStreamResponse(StreamResponse): """ text: str - from_variable_selector: Optional[list[str]] = None + from_variable_selector: list[str] | None = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -671,7 +671,7 @@ class WorkflowAppStreamResponse(AppStreamResponse): WorkflowAppStreamResponse entity """ - workflow_run_id: Optional[str] = None + workflow_run_id: str | None = None class AppBlockingResponse(BaseModel): @@ -697,7 +697,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): conversation_id: str message_id: str answer: str - metadata: dict = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) created_at: int data: Data @@ -717,7 +717,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): mode: str message_id: str answer: str - metadata: dict = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) created_at: int data: Data @@ -736,8 +736,8 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[Mapping[str, Any]] = None - error: Optional[str] = None + outputs: Mapping[str, Any] | None = None + error: str | None = None elapsed_time: float total_tokens: int total_steps: int @@ -765,7 +765,7 @@ class AgentLogStreamResponse(StreamResponse): error: str | None = None status: str data: Mapping[str, Any] - metadata: Optional[Mapping[str, Any]] = None + metadata: Mapping[str, object] = Field(default_factory=dict) node_id: str event: StreamEvent = StreamEvent.AGENT_LOG diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 89190c36cc..1e60e14e34 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -110,7 +110,9 @@ class AppGeneratorTTSPublisher: elif isinstance(message.event, QueueNodeSucceededEvent): if message.event.outputs is None: continue - self.msg_text += message.event.outputs.get("output", "") + output = message.event.outputs.get("output", "") + if isinstance(output, str): + self.msg_text += output self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.max_sentence, 7): @@ -120,7 +122,7 @@ class AppGeneratorTTSPublisher: _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice ) future_queue.put(futures_result) - if text_tmp: + if isinstance(text_tmp, str): self.msg_text = text_tmp else: self.msg_text = "" diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index a2c3157b3b..e939edade5 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -161,7 +161,7 @@ class 
ErrorData(BaseModel): sentence. """ - data: Any | None = None + data: Any = None """ Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py index 00120425c9..aca879df7d 100644 --- a/api/core/rag/entities/citation_metadata.py +++ b/api/core/rag/entities/citation_metadata.py @@ -1,23 +1,23 @@ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel class RetrievalSourceMetadata(BaseModel): - position: Optional[int] = None - dataset_id: Optional[str] = None - dataset_name: Optional[str] = None - document_id: Optional[str] = None - document_name: Optional[str] = None - data_source_type: Optional[str] = None - segment_id: Optional[str] = None - retriever_from: Optional[str] = None - score: Optional[float] = None - hit_count: Optional[int] = None - word_count: Optional[int] = None - segment_position: Optional[int] = None - index_node_hash: Optional[str] = None - content: Optional[str] = None - page: Optional[int] = None - doc_metadata: Optional[dict[str, Any]] = None - title: Optional[str] = None + position: int | None = None + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + data_source_type: str | None = None + segment_id: str | None = None + retriever_from: str | None = None + score: float | None = None + hit_count: int | None = None + word_count: int | None = None + segment_position: int | None = None + index_node_hash: str | None = None + content: str | None = None + page: int | None = None + doc_metadata: dict[str, Any] | None = None + title: str | None = None diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ca3be26ff9..00c4ab9dd7 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import datetime -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -14,12 +15,12 @@ class ToolApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] = None + parameters: list[ToolParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: Mapping[str, object] = Field(default_factory=dict) -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]] +ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None class ToolProviderApiEntity(BaseModel): @@ -27,26 +28,26 @@ class ToolProviderApiEntity(BaseModel): author: str name: str # identifier description: I18nObject - icon: str | dict - icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool") + icon: str | Mapping[str, str] + icon_dark: str | Mapping[str, str] = "" label: I18nObject # label type: ToolProviderType - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: Mapping[str, object] = Field(default_factory=dict) + original_credentials: Mapping[str, object] = Field(default_factory=dict) is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the 
tool") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") - tools: list[ToolApiEntity] = Field(default_factory=list) + plugin_id: str | None = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool") + tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity]) labels: list[str] = Field(default_factory=list) # MCP - server_url: Optional[str] = Field(default="", description="The server url of the tool") + server_url: str | None = Field(default="", description="The server url of the tool") updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) - server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool") - timeout: Optional[float] = Field(default=30.0, description="The timeout of the MCP tool") - sse_read_timeout: Optional[float] = Field(default=300.0, description="The SSE read timeout of the MCP tool") - masked_headers: Optional[dict[str, str]] = Field(default=None, description="The masked headers of the MCP tool") - original_headers: Optional[dict[str, str]] = Field(default=None, description="The original headers of the MCP tool") + server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool") + timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool") + sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool") + masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool") + original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool") @field_validator("tools", mode="before") @classmethod @@ -105,7 +106,7 @@ class ToolProviderCredentialApiEntity(BaseModel): is_default: bool = Field( default=False, description="Whether the credential is the default credential for the provider in the workspace" ) - credentials: dict = Field(description="The credentials of the provider") + credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict) class ToolProviderCredentialInfoApiEntity(BaseModel): diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 66304b30a5..077949906c 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -3,7 +3,7 @@ import contextlib import enum from collections.abc import Mapping from enum import Enum -from typing import Any, Optional, Union +from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -183,11 +183,11 @@ class ToolInvokeMessage(BaseModel): id: str label: str = Field(..., description="The label of the log") - parent_id: Optional[str] = Field(default=None, description="Leave empty for root log") - error: Optional[str] = Field(default=None, description="The error message") + parent_id: str | None = Field(default=None, description="Leave empty for root log") + error: str | None = Field(default=None, description="The error message") status: LogStatus = Field(..., description="The status of the log") data: Mapping[str, Any] = Field(..., description="Detailed log data") - metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + 
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log") class RetrieverResourceMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") @@ -242,7 +242,7 @@ class ToolInvokeMessage(BaseModel): class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - file_var: Optional[dict[str, Any]] = None + file_var: dict[str, Any] | None = None class ToolParameter(PluginParameter): @@ -286,11 +286,11 @@ class ToolParameter(PluginParameter): LLM = "llm" # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") + human_description: I18nObject | None = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") - llm_description: Optional[str] = None + llm_description: str | None = None # MCP object and array type parameters use this field to store the schema - input_schema: Optional[dict] = None + input_schema: dict | None = None @classmethod def get_simple_instance( @@ -299,7 +299,7 @@ class ToolParameter(PluginParameter): llm_description: str, typ: ToolParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "ToolParameter": """ get a simple tool parameter @@ -340,9 +340,9 @@ class ToolProviderIdentity(BaseModel): name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") - icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool") + icon_dark: str | None = Field(default=None, description="The dark icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -353,7 +353,7 @@ class ToolIdentity(BaseModel): name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") provider: str = Field(..., description="The provider of the tool") - icon: Optional[str] = None + icon: str | None = None class ToolDescription(BaseModel): @@ -363,9 +363,9 @@ class ToolDescription(BaseModel): class ToolEntity(BaseModel): identity: ToolIdentity - parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - output_schema: Optional[dict] = None + parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter]) + description: ToolDescription | None = None + output_schema: Mapping[str, object] = Field(default_factory=dict) has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") # pydantic configs @@ -378,21 +378,23 @@ class ToolEntity(BaseModel): class OAuthSchema(BaseModel): - client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + client_schema: list[ProviderConfig] = Field( + default_factory=list[ProviderConfig], description="The schema of the OAuth client" + ) credentials_schema: list[ProviderConfig] = Field( - default_factory=list, 
description="The schema of the OAuth credentials" + default_factory=list[ProviderConfig], description="The schema of the OAuth credentials" ) class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - plugin_id: Optional[str] = None - credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + plugin_id: str | None = None + credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig]) + oauth_schema: OAuthSchema | None = None class ToolProviderEntityWithPlugin(ToolProviderEntity): - tools: list[ToolEntity] = Field(default_factory=list) + tools: list[ToolEntity] = Field(default_factory=list[ToolEntity]) class WorkflowToolParameterConfiguration(BaseModel): @@ -411,8 +413,8 @@ class ToolInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "ToolInvokeMeta": @@ -464,11 +466,11 @@ class ToolSelector(BaseModel): type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") required: bool = Field(..., description="Whether the parameter is required") description: str = Field(..., description="The description of the parameter") - default: Optional[Union[int, float, str]] = None - options: Optional[list[PluginParameterOption]] = None + default: Union[int, float, str] | None = None + options: list[PluginParameterOption] | None = None provider_id: str = Field(..., description="The id of the provider") - credential_id: Optional[str] = Field(default=None, description="The id of the credential") + credential_id: str | None = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5f6eb045ab..1b9c631f81 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -72,7 +72,6 @@ class MCPToolProviderController(ToolProviderController): ), llm=remote_mcp_tool.description or "", ), - output_schema=None, has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0, ) for remote_mcp_tool in remote_mcp_tools diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5c836cfcd2..766e0568c4 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -886,7 +886,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str): + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -897,13 +897,13 @@ class ToolManager: if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - icon: dict = json.loads(workflow_provider.icon) + icon = json.loads(workflow_provider.icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str): + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: try: 
api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -914,13 +914,13 @@ class ToolManager: if api_provider is None: raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - icon: dict = json.loads(api_provider.icon) + icon = json.loads(api_provider.icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: try: mcp_provider: MCPToolProvider | None = ( db.session.query(MCPToolProvider) @@ -941,7 +941,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> Union[str, dict[str, Any]]: + ) -> str | Mapping[str, str]: """ get the tool icon @@ -966,11 +966,10 @@ class ToolManager: return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) elif provider_type == ToolProviderType.PLUGIN: provider = ToolManager.get_plugin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} raise ValueError(f"plugin provider {provider_id} not found") elif provider_type == ToolProviderType.MCP: return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) diff --git a/api/core/workflow/graph_events/agent.py b/api/core/workflow/graph_events/agent.py index 67d94d25eb..759fe3a71c 100644 --- a/api/core/workflow/graph_events/agent.py +++ b/api/core/workflow/graph_events/agent.py @@ -14,4 +14,4 @@ class NodeRunAgentLogEvent(GraphAgentNodeEventBase): error: str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") + metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/workflow/graph_events/graph.py b/api/core/workflow/graph_events/graph.py index 4f7e886519..5d13833faa 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/core/workflow/graph_events/graph.py @@ -1,5 +1,3 @@ -from typing import Any - from pydantic import Field from core.workflow.graph_events import BaseGraphEvent @@ -10,7 +8,7 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - outputs: dict[str, Any] | None = None + outputs: dict[str, object] = Field(default_factory=dict) class GraphRunFailedEvent(BaseGraphEvent): @@ -20,11 +18,11 @@ class GraphRunFailedEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent): exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, Any] | None = None + outputs: dict[str, object] = Field(default_factory=dict) class GraphRunAbortedEvent(BaseGraphEvent): """Event emitted when a graph run is aborted by user command.""" reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, Any] | None = Field(default=None, description="partial outputs if any") + outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any") diff --git 
a/api/core/workflow/graph_events/iteration.py b/api/core/workflow/graph_events/iteration.py index 3d507dbe46..28627395fd 100644 --- a/api/core/workflow/graph_events/iteration.py +++ b/api/core/workflow/graph_events/iteration.py @@ -10,8 +10,8 @@ from .base import GraphNodeEventBase class NodeRunIterationStartedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) predecessor_node_id: str | None = None @@ -24,17 +24,17 @@ class NodeRunIterationNextEvent(GraphNodeEventBase): class NodeRunIterationSucceededEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 class NodeRunIterationFailedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_events/loop.py b/api/core/workflow/graph_events/loop.py index c0b540949b..7cdc5427e2 100644 --- a/api/core/workflow/graph_events/loop.py +++ b/api/core/workflow/graph_events/loop.py @@ -10,8 +10,8 @@ from .base import GraphNodeEventBase class NodeRunLoopStartedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) predecessor_node_id: str | None = None @@ -24,17 +24,17 @@ class NodeRunLoopNextEvent(GraphNodeEventBase): class NodeRunLoopSucceededEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 class NodeRunLoopFailedEvent(GraphNodeEventBase): node_title: str start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index c6365d39c1..1d35a69c4a 100644 --- a/api/core/workflow/graph_events/node.py +++ 
b/api/core/workflow/graph_events/node.py @@ -12,7 +12,6 @@ from .base import GraphNodeEventBase class NodeRunStartedEvent(GraphNodeEventBase): node_title: str predecessor_node_id: str | None = None - parallel_mode_run_id: str | None = None agent_strategy: AgentNodeStrategyInit | None = None start_at: datetime = Field(..., description="node start time") diff --git a/api/core/workflow/node_events/agent.py b/api/core/workflow/node_events/agent.py index e5fc46ddea..bf295ec774 100644 --- a/api/core/workflow/node_events/agent.py +++ b/api/core/workflow/node_events/agent.py @@ -14,5 +14,5 @@ class AgentLogEvent(NodeEventBase): error: str | None = Field(..., description="error") status: str = Field(..., description="status") data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] | None = Field(default=None, description="metadata") + metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") node_id: str = Field(..., description="node id") diff --git a/api/core/workflow/node_events/iteration.py b/api/core/workflow/node_events/iteration.py index db0b41a43a..744ddea628 100644 --- a/api/core/workflow/node_events/iteration.py +++ b/api/core/workflow/node_events/iteration.py @@ -9,8 +9,8 @@ from .base import NodeEventBase class IterationStartedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) predecessor_node_id: str | None = None @@ -21,16 +21,16 @@ class IterationNextEvent(NodeEventBase): class IterationSucceededEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 class IterationFailedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/node_events/loop.py b/api/core/workflow/node_events/loop.py index 4e84fb0061..3ae230f9f6 100644 --- a/api/core/workflow/node_events/loop.py +++ b/api/core/workflow/node_events/loop.py @@ -9,8 +9,8 @@ from .base import NodeEventBase class LoopStartedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) predecessor_node_id: str | None = None @@ -21,16 +21,16 @@ class LoopNextEvent(NodeEventBase): class LoopSucceededEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = 
Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 class LoopFailedEvent(NodeEventBase): start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, Any] | None = None - outputs: Mapping[str, Any] | None = None - metadata: Mapping[str, Any] | None = None + inputs: Mapping[str, object] = Field(default_factory=dict) + outputs: Mapping[str, object] = Field(default_factory=dict) + metadata: Mapping[str, object] = Field(default_factory=dict) steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index de6f4152c6..ce089003cf 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -2,12 +2,12 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any, ClassVar +from typing import Any, ClassVar from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus +from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState +from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from core.workflow.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, @@ -46,11 +46,6 @@ from models.enums import UserFrom from .entities import BaseNodeData, RetryConfig -if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.enums import ErrorStrategy, NodeType - from core.workflow.node_events import NodeRunResult - logger = logging.getLogger(__name__) diff --git a/api/models/model.py b/api/models/model.py index 58a75c355c..c479bb666b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1138,7 +1138,7 @@ class Message(Base): ) @property - def retriever_resources(self) -> Any | list[Any]: + def retriever_resources(self) -> Any: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property diff --git a/api/models/tools.py b/api/models/tools.py index 277a9d032c..545c29357d 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,4 +1,5 @@ import json +from collections.abc import Mapping from datetime import datetime from typing import TYPE_CHECKING, Any, Optional, cast from urllib.parse import urlparse @@ -314,11 +315,11 @@ class MCPToolProvider(Base): return [MCPTool(**tool) for tool in json.loads(self.tools)] @property - def provider_icon(self) -> dict[str, str] | str: + def provider_icon(self) -> Mapping[str, str] | str: from core.file import helpers as file_helpers try: - return cast(dict[str, str], json.loads(self.icon)) + return json.loads(self.icon) except json.JSONDecodeError: return file_helpers.get_signed_file_url(self.icon) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index f5fc7f951f..49d3fd57ad 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,7 @@ import json import logging -from typing import Any, Optional, Union, cast +from collections.abc import Mapping +from typing import Any, Union from 
yarl import URL @@ -38,7 +39,9 @@ class ToolTransformService: return str(url_prefix % {"tenant_id": tenant_id, "filename": filename}) @classmethod - def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]: + def get_tool_provider_icon_url( + cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str] + ) -> str | Mapping[str, str]: """ get tool provider icon url """ @@ -51,7 +54,7 @@ class ToolTransformService: elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: if isinstance(icon, str): - return cast(dict, json.loads(icon)) + return json.loads(icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -94,7 +97,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( cls, provider_controller: BuiltinToolProviderController | PluginToolProviderController, - db_provider: Optional[BuiltinToolProvider], + db_provider: BuiltinToolProvider | None, decrypt_credentials: bool = True, ) -> ToolProviderApiEntity: """ @@ -106,7 +109,7 @@ class ToolTransformService: name=provider_controller.entity.identity.name, description=provider_controller.entity.identity.description, icon=provider_controller.entity.identity.icon, - icon_dark=provider_controller.entity.identity.icon_dark, + icon_dark=provider_controller.entity.identity.icon_dark or "", label=provider_controller.entity.identity.label, type=ToolProviderType.BUILT_IN, masked_credentials={}, @@ -128,9 +131,10 @@ class ToolTransformService: ) } + masked_creds = {} for name in schema: - if result.masked_credentials: - result.masked_credentials[name] = "" + masked_creds[name] = "" + result.masked_credentials = masked_creds # check if the provider need credentials if not provider_controller.need_credentials: @@ -208,7 +212,7 @@ class ToolTransformService: name=provider_controller.entity.identity.name, description=provider_controller.entity.identity.description, icon=provider_controller.entity.identity.icon, - icon_dark=provider_controller.entity.identity.icon_dark, + icon_dark=provider_controller.entity.identity.icon_dark or "", label=provider_controller.entity.identity.label, type=ToolProviderType.WORKFLOW, masked_credentials={}, @@ -321,7 +325,7 @@ class ToolTransformService: @staticmethod def convert_tool_entity_to_api_entity( - tool: Union[ApiToolBundle, WorkflowTool, Tool], + tool: ApiToolBundle | WorkflowTool | Tool, tenant_id: str, labels: list[str] | None = None, ) -> ToolApiEntity: @@ -375,7 +379,7 @@ class ToolTransformService: parameters=merged_parameters, labels=labels or [], ) - elif isinstance(tool, ApiToolBundle): + else: return ToolApiEntity( author=tool.author, name=tool.operation_id or "", @@ -384,9 +388,6 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) - else: - # Handle WorkflowTool case - raise ValueError(f"Unsupported tool type: {type(tool)}") @staticmethod def convert_builtin_provider_to_credential_entity( diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 5348f729f9..17e3ebeea0 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -17,7 +17,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), parameters=[], 
description=None, - output_schema=None, has_runtime_parameters=False, ) runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) From ba5df3612bc32e4b96d5f93db4d037d68d25d015 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 11 Sep 2025 15:13:18 +0800 Subject: [PATCH 17/31] fix: tests Signed-off-by: -LAN- --- api/core/workflow/nodes/base/node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index ce089003cf..6f2a8fc2f3 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -144,6 +144,8 @@ class Node: elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] event.id = self._node_execution_id yield event + else: + yield event except Exception as e: logger.exception("Node %s failed to run", self._node_id) result = NodeRunResult( From 8fb69429f9c9acdb06307004da03f0653d5e6171 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 11 Sep 2025 15:37:46 +0800 Subject: [PATCH 18/31] feat(graph_engine): support parallel mode in iteration node Signed-off-by: -LAN- --- .../nodes/iteration/iteration_node.py | 120 ++++++++++++++++-- 1 file changed, 107 insertions(+), 13 deletions(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 10fe7473bb..1547f8ac3e 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,5 +1,6 @@ import logging from collections.abc import Generator, Mapping, Sequence +from concurrent.futures import Future, ThreadPoolExecutor, as_completed from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast @@ -190,22 +191,115 @@ class IterationNode(Node): outputs: list[Any], iter_run_map: dict[str, float], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, + if self._node_data.is_parallel: + # Parallel mode execution + yield from self._execute_parallel_iterations( + iterator_list_value=iterator_list_value, outputs=outputs, - graph_engine=graph_engine, + iter_run_map=iter_run_map, ) + else: + # Sequential mode execution + for index, item in enumerate(iterator_list_value): + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + yield IterationNextEvent(index=index) - # Update the total tokens from this iteration - self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + graph_engine = self._create_graph_engine(index, item) + + # Run the iteration + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, + ) + + # Update the total tokens from this iteration + self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + + def _execute_parallel_iterations( + self, + iterator_list_value: Sequence[object], + 
outputs: list[Any], + iter_run_map: dict[str, float], + ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: + # Initialize outputs list with None values to maintain order + outputs.extend([None] * len(iterator_list_value)) + + # Determine the number of parallel workers + max_workers = min(self._node_data.parallel_nums, len(iterator_list_value)) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all iteration tasks + future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {} + for index, item in enumerate(iterator_list_value): + yield IterationNextEvent(index=index) + future = executor.submit( + self._execute_single_iteration_parallel, + index=index, + item=item, + ) + future_to_index[future] = index + + # Process completed iterations as they finish + for future in as_completed(future_to_index): + index = future_to_index[future] + try: + result = future.result() + iter_start_at, events, output_value, tokens_used = result + + # Update outputs at the correct index + outputs[index] = output_value + + # Yield all events from this iteration + yield from events + + # Update tokens and timing + self.graph_runtime_state.total_tokens += tokens_used + iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + + except Exception as e: + # Handle errors based on error_handle_mode + match self._node_data.error_handle_mode: + case ErrorHandleMode.TERMINATED: + # Cancel remaining futures and re-raise + for f in future_to_index: + if f != future: + f.cancel() + raise IterationNodeError(str(e)) + case ErrorHandleMode.CONTINUE_ON_ERROR: + outputs[index] = None + case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + outputs[index] = None # Will be filtered later + + # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode + if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + outputs[:] = [output for output in outputs if output is not None] + + def _execute_single_iteration_parallel( + self, + index: int, + item: Any, + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: + """Execute a single iteration in parallel mode and return results.""" + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + events: list[GraphNodeEventBase] = [] + outputs_temp: list[object] = [] + + graph_engine = self._create_graph_engine(index, item) + + # Collect events instead of yielding them directly + for event in self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs_temp, + graph_engine=graph_engine, + ): + events.append(event) + + # Get the output value from the temporary outputs list + output_value = outputs_temp[0] if outputs_temp else None + + return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens def _handle_iteration_success( self, From 872cff7bab39c60d9606c666dfd196400565a751 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 11 Sep 2025 15:40:12 +0800 Subject: [PATCH 19/31] chore(iteration_node): convert some Any to object Signed-off-by: -LAN- --- .../workflow/nodes/iteration/iteration_node.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 1547f8ac3e..524cd2c40b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -111,7 +111,7 @@ class IterationNode(Node): 
started_at = naive_utc_now() iter_run_map: dict[str, float] = {} - outputs: list[Any] = [] + outputs: list[object] = [] yield IterationStartedEvent( start_at=started_at, @@ -188,7 +188,7 @@ class IterationNode(Node): def _execute_iterations( self, iterator_list_value: Sequence[object], - outputs: list[Any], + outputs: list[object], iter_run_map: dict[str, float], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: if self._node_data.is_parallel: @@ -220,7 +220,7 @@ class IterationNode(Node): def _execute_parallel_iterations( self, iterator_list_value: Sequence[object], - outputs: list[Any], + outputs: list[object], iter_run_map: dict[str, float], ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # Initialize outputs list with None values to maintain order @@ -279,7 +279,7 @@ class IterationNode(Node): def _execute_single_iteration_parallel( self, index: int, - item: Any, + item: object, ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]: """Execute a single iteration in parallel mode and return results.""" iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -305,7 +305,7 @@ class IterationNode(Node): self, started_at: datetime, inputs: dict[str, Sequence[object]], - outputs: list[Any], + outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], ) -> Generator[NodeEventBase, None, None]: @@ -335,7 +335,7 @@ class IterationNode(Node): self, started_at: datetime, inputs: dict[str, Sequence[object]], - outputs: list[Any], + outputs: list[object], iterator_list_value: Sequence[object], iter_run_map: dict[str, float], error: IterationNodeError, @@ -502,7 +502,7 @@ class IterationNode(Node): case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: return - def _create_graph_engine(self, index: int, item: Any): + def _create_graph_engine(self, index: int, item: object): # Import dependencies from core.workflow.entities import GraphInitParams, GraphRuntimeState from core.workflow.graph import Graph From 3c668e4a5c0d77b1187c3087e80014ffe70fb937 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 11 Sep 2025 16:41:10 +0800 Subject: [PATCH 20/31] fix: update test assertions for ToolProviderApiEntity validation - Fixed test_repack_provider_entity_no_dark_icon to use empty string instead of None for icon_dark field - Updated test_builtin_provider_to_user_provider_no_credentials assertion to match actual implementation behavior where masked_credentials always contains empty strings for schema fields --- .../services/tools/test_tools_transform_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index bf25968100..827f9c010e 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -454,7 +454,7 @@ class TestToolTransformService: name=fake.company(), description=I18nObject(en_US=fake.text(max_nb_chars=100)), icon='{"background": "#FF6B6B", "content": "🔧"}', - icon_dark=None, + icon_dark="", label=I18nObject(en_US=fake.company()), type=ToolProviderType.API, masked_credentials={}, @@ -473,8 +473,8 @@ class TestToolTransformService: assert provider.icon["background"] == "#FF6B6B" assert provider.icon["content"] == "🔧" - # Verify dark icon remains None - assert provider.icon_dark is None 
+ # Verify dark icon remains empty string + assert provider.icon_dark == "" def test_builtin_provider_to_user_provider_success( self, db_session_with_containers, mock_external_service_dependencies @@ -628,7 +628,7 @@ class TestToolTransformService: assert result is not None assert result.is_team_authorization is True assert result.allow_delete is False - assert result.masked_credentials == {} + assert result.masked_credentials == {"api_key": ""} def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): """ From 4cdc19fd0523dc9de241329e9773e3771b18eaae Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 14 Sep 2025 04:19:24 +0800 Subject: [PATCH 21/31] feat(graph_engine): add abstract layer and dump / load methods for ready queue. --- api/core/workflow/README.md | 114 ++++++++++++++ .../workflow/graph_engine/graph_engine.py | 3 +- .../graph_engine/graph_state_manager.py | 5 +- .../graph_engine/ready_queue/__init__.py | 11 ++ .../graph_engine/ready_queue/in_memory.py | 142 ++++++++++++++++++ .../graph_engine/ready_queue/protocol.py | 88 +++++++++++ api/core/workflow/graph_engine/worker.py | 6 +- .../worker_management/worker_pool.py | 5 +- 8 files changed, 367 insertions(+), 7 deletions(-) create mode 100644 api/core/workflow/README.md create mode 100644 api/core/workflow/graph_engine/ready_queue/__init__.py create mode 100644 api/core/workflow/graph_engine/ready_queue/in_memory.py create mode 100644 api/core/workflow/graph_engine/ready_queue/protocol.py diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md new file mode 100644 index 0000000000..53e910e7b6 --- /dev/null +++ b/api/core/workflow/README.md @@ -0,0 +1,114 @@ +# Workflow + +## Project Overview + +This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control. + +## Architecture + +### Core Components + +The graph engine follows a layered architecture with strict dependency rules: + +1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution + - **Manager** - External control interface for stop/pause/resume commands + - **Worker** - Node execution runtime + - **Command Processing** - Handles control commands (abort, pause, resume) + - **Event Management** - Event propagation and layer notifications + - **Graph Traversal** - Edge processing and skip propagation + - **Response Coordinator** - Path tracking and session management + - **Layers** - Pluggable middleware (debug logging, execution limits) + - **Command Channels** - Communication channels (InMemory, Redis) + +2. **Graph** (`graph/`) - Graph structure and runtime state + - **Graph Template** - Workflow definition + - **Edge** - Node connections with conditions + - **Runtime State Protocol** - State management interface + +3. **Nodes** (`nodes/`) - Node implementations + - **Base** - Abstract node classes and variable parsing + - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. + +4. **Events** (`node_events/`) - Event system + - **Base** - Event protocols + - **Node Events** - Node lifecycle events + +5. 
**Entities** (`entities/`) - Domain models + - **Variable Pool** - Variable storage + - **Graph Init Params** - Initialization configuration + +## Key Design Patterns + +### Command Channel Pattern +External workflow control via Redis or in-memory channels: +```python +# Send stop command to running workflow +channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") +channel.send_command(AbortCommand(reason="User requested")) +``` + +### Layer System +Extensible middleware for cross-cutting concerns: +```python +engine = GraphEngine(graph) +engine.add_layer(DebugLoggingLayer(level="INFO")) +engine.add_layer(ExecutionLimitsLayer(max_nodes=100)) +``` + +### Event-Driven Architecture +All node executions emit events for monitoring and integration: +- `NodeRunStartedEvent` - Node execution begins +- `NodeRunSucceededEvent` - Node completes successfully +- `NodeRunFailedEvent` - Node encounters error +- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle + +### Variable Pool +Centralized variable storage with namespace isolation: +```python +# Variables scoped by node_id +pool.add(["node1", "output"], value) +result = pool.get(["node1", "output"]) +``` + +## Import Architecture Rules + +The codebase enforces strict layering via import-linter: + +1. **Workflow Layers** (top to bottom): + - graph_engine → graph_events → graph → nodes → node_events → entities + +2. **Graph Engine Internal Layers**: + - orchestration → command_processing → event_management → graph_traversal → domain + +3. **Domain Isolation**: + - Domain models cannot import from infrastructure layers + +4. **Command Channel Independence**: + - InMemory and Redis channels must remain independent + +## Common Tasks + +### Adding a New Node Type + +1. Create node class in `nodes//` +2. Inherit from `BaseNode` or appropriate base class +3. Implement `_run()` method +4. Register in `nodes/node_mapping.py` +5. Add tests in `tests/unit_tests/core/workflow/nodes/` + +### Implementing a Custom Layer + +1. Create class inheriting from `Layer` base +2. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` +3. 
Add to engine via `engine.add_layer()` + +### Debugging Workflow Execution + +Enable debug logging layer: +```python +debug_layer = DebugLoggingLayer( + level="DEBUG", + include_inputs=True, + include_outputs=True +) +``` diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ff56605d3d..6e58d19fd6 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -38,6 +38,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel +from .ready_queue import InMemoryReadyQueue from .response_coordinator import ResponseStreamCoordinator from .worker_management import WorkerPool @@ -104,7 +105,7 @@ class GraphEngine: # === Execution Queues === # Queue for nodes ready to execute - self._ready_queue: queue.Queue[str] = queue.Queue() + self._ready_queue = InMemoryReadyQueue() # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/core/workflow/graph_engine/graph_state_manager.py index efc3992ac9..22a3a826fc 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/core/workflow/graph_engine/graph_state_manager.py @@ -2,7 +2,6 @@ Graph state manager that combines node, edge, and execution tracking. """ -import queue import threading from collections.abc import Sequence from typing import TypedDict, final @@ -10,6 +9,8 @@ from typing import TypedDict, final from core.workflow.enums import NodeState from core.workflow.graph import Edge, Graph +from .ready_queue import ReadyQueue + class EdgeStateAnalysis(TypedDict): """Analysis result for edge states.""" @@ -21,7 +22,7 @@ class EdgeStateAnalysis(TypedDict): @final class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None: + def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: """ Initialize the state manager. diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/core/workflow/graph_engine/ready_queue/__init__.py new file mode 100644 index 0000000000..9b890880f5 --- /dev/null +++ b/api/core/workflow/graph_engine/ready_queue/__init__.py @@ -0,0 +1,11 @@ +""" +Ready queue implementations for GraphEngine. + +This package contains the protocol and implementations for managing +the queue of nodes ready for execution. +""" + +from .in_memory import InMemoryReadyQueue +from .protocol import ReadyQueue + +__all__ = ["InMemoryReadyQueue", "ReadyQueue"] diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py new file mode 100644 index 0000000000..90df9a0096 --- /dev/null +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -0,0 +1,142 @@ +""" +In-memory implementation of the ReadyQueue protocol. + +This implementation wraps Python's standard queue.Queue and adds +serialization capabilities for state storage. +""" + +import queue +from typing import final + + +@final +class InMemoryReadyQueue: + """ + In-memory ready queue implementation with serialization support. + + This implementation uses Python's queue.Queue internally and provides + methods to serialize and restore the queue state. + """ + + def __init__(self, maxsize: int = 0) -> None: + """ + Initialize the in-memory ready queue. 
+ + Args: + maxsize: Maximum size of the queue (0 for unlimited) + """ + self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) + + def put(self, item: str) -> None: + """ + Add a node ID to the ready queue. + + Args: + item: The node ID to add to the queue + """ + self._queue.put(item) + + def get(self, timeout: float | None = None) -> str: + """ + Retrieve and remove a node ID from the queue. + + Args: + timeout: Maximum time to wait for an item (None for blocking) + + Returns: + The node ID retrieved from the queue + + Raises: + queue.Empty: If timeout expires and no item is available + """ + if timeout is None: + return self._queue.get(block=True) + return self._queue.get(timeout=timeout) + + def task_done(self) -> None: + """ + Indicate that a previously retrieved task is complete. + + Used by worker threads to signal task completion for + join() synchronization. + """ + self._queue.task_done() + + def empty(self) -> bool: + """ + Check if the queue is empty. + + Returns: + True if the queue has no items, False otherwise + """ + return self._queue.empty() + + def qsize(self) -> int: + """ + Get the approximate size of the queue. + + Returns: + The approximate number of items in the queue + """ + return self._queue.qsize() + + def dumps(self) -> dict[str, object]: + """ + Serialize the queue state for storage. + + Returns: + A dictionary containing the serialized queue state + """ + # Extract all items from the queue without removing them + items: list[str] = [] + temp_items: list[str] = [] + + # Drain the queue temporarily to get all items + while not self._queue.empty(): + try: + item = self._queue.get_nowait() + temp_items.append(item) + items.append(item) + except queue.Empty: + break + + # Put items back in the same order + for item in temp_items: + self._queue.put(item) + + return { + "type": "InMemoryReadyQueue", + "version": "1.0", + "items": items, + "maxsize": self._queue.maxsize, + } + + def loads(self, data: dict[str, object]) -> None: + """ + Restore the queue state from serialized data. + + Args: + data: The serialized queue state to restore + """ + if data.get("type") != "InMemoryReadyQueue": + raise ValueError(f"Invalid serialized data type: {data.get('type')}") + + if data.get("version") != "1.0": + raise ValueError(f"Unsupported version: {data.get('version')}") + + # Clear the current queue + while not self._queue.empty(): + try: + self._queue.get_nowait() + except queue.Empty: + break + + # Restore items + items = data.get("items", []) + if not isinstance(items, list): + raise ValueError("Invalid items data: expected list") + + for item in items: + if not isinstance(item, str): + raise ValueError(f"Invalid item type: expected str, got {type(item).__name__}") + self._queue.put(item) diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/core/workflow/graph_engine/ready_queue/protocol.py new file mode 100644 index 0000000000..0e457bcf05 --- /dev/null +++ b/api/core/workflow/graph_engine/ready_queue/protocol.py @@ -0,0 +1,88 @@ +""" +ReadyQueue protocol for GraphEngine node execution queue. + +This protocol defines the interface for managing the queue of nodes ready +for execution, supporting both in-memory and persistent storage scenarios. +""" + +from typing import Protocol + + +class ReadyQueue(Protocol): + """ + Protocol for managing nodes ready for execution in GraphEngine. 
+ + This protocol defines the interface that any ready queue implementation + must provide, enabling both in-memory queues and persistent queues + that can be serialized for state storage. + """ + + def put(self, item: str) -> None: + """ + Add a node ID to the ready queue. + + Args: + item: The node ID to add to the queue + """ + ... + + def get(self, timeout: float | None = None) -> str: + """ + Retrieve and remove a node ID from the queue. + + Args: + timeout: Maximum time to wait for an item (None for blocking) + + Returns: + The node ID retrieved from the queue + + Raises: + queue.Empty: If timeout expires and no item is available + """ + ... + + def task_done(self) -> None: + """ + Indicate that a previously retrieved task is complete. + + Used by worker threads to signal task completion for + join() synchronization. + """ + ... + + def empty(self) -> bool: + """ + Check if the queue is empty. + + Returns: + True if the queue has no items, False otherwise + """ + ... + + def qsize(self) -> int: + """ + Get the approximate size of the queue. + + Returns: + The approximate number of items in the queue + """ + ... + + def dumps(self) -> dict[str, object]: + """ + Serialize the queue state for storage. + + Returns: + A dictionary containing the serialized queue state + that can be persisted and later restored + """ + ... + + def loads(self, data: dict[str, object]) -> None: + """ + Restore the queue state from serialized data. + + Args: + data: The serialized queue state to restore + """ + ... diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index e7462309c9..42c9b936dd 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -22,6 +22,8 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent from core.workflow.nodes.base.node import Node from libs.flask_utils import preserve_flask_contexts +from .ready_queue import ReadyQueue + @final class Worker(threading.Thread): @@ -35,7 +37,7 @@ class Worker(threading.Thread): def __init__( self, - ready_queue: queue.Queue[str], + ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, worker_id: int = 0, @@ -46,7 +48,7 @@ class Worker(threading.Thread): Initialize worker thread. Args: - ready_queue: Queue containing node IDs ready for execution + ready_queue: Ready queue containing node IDs ready for execution event_queue: Queue for pushing execution events graph: Graph containing nodes to execute worker_id: Unique identifier for this worker diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/core/workflow/graph_engine/worker_management/worker_pool.py index 00328fbda1..a9aada9ea5 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/core/workflow/graph_engine/worker_management/worker_pool.py @@ -14,6 +14,7 @@ from configs import dify_config from core.workflow.graph import Graph from core.workflow.graph_events import GraphNodeEventBase +from ..ready_queue import ReadyQueue from ..worker import Worker logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ class WorkerPool: def __init__( self, - ready_queue: queue.Queue[str], + ready_queue: ReadyQueue, event_queue: queue.Queue[GraphNodeEventBase], graph: Graph, flask_app: "Flask | None" = None, @@ -49,7 +50,7 @@ class WorkerPool: Initialize the simple worker pool. 
Args: - ready_queue: Queue of nodes ready for execution + ready_queue: Ready queue for nodes ready for execution event_queue: Queue for worker events graph: The workflow graph flask_app: Optional Flask app for context preservation From 0f15a2bacaea6cabeca910f971b03e85caa32d5c Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sat, 13 Sep 2025 20:20:53 +0000 Subject: [PATCH 22/31] [autofix.ci] apply automated fixes --- api/core/workflow/README.md | 44 ++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/api/core/workflow/README.md b/api/core/workflow/README.md index 53e910e7b6..bef19ba90b 100644 --- a/api/core/workflow/README.md +++ b/api/core/workflow/README.md @@ -11,6 +11,7 @@ This is the workflow graph engine module of Dify, implementing a queue-based dis The graph engine follows a layered architecture with strict dependency rules: 1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution + - **Manager** - External control interface for stop/pause/resume commands - **Worker** - Node execution runtime - **Command Processing** - Handles control commands (abort, pause, resume) @@ -20,27 +21,33 @@ The graph engine follows a layered architecture with strict dependency rules: - **Layers** - Pluggable middleware (debug logging, execution limits) - **Command Channels** - Communication channels (InMemory, Redis) -2. **Graph** (`graph/`) - Graph structure and runtime state +1. **Graph** (`graph/`) - Graph structure and runtime state + - **Graph Template** - Workflow definition - **Edge** - Node connections with conditions - **Runtime State Protocol** - State management interface -3. **Nodes** (`nodes/`) - Node implementations +1. **Nodes** (`nodes/`) - Node implementations + - **Base** - Abstract node classes and variable parsing - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. -4. **Events** (`node_events/`) - Event system +1. **Events** (`node_events/`) - Event system + - **Base** - Event protocols - **Node Events** - Node lifecycle events -5. **Entities** (`entities/`) - Domain models +1. **Entities** (`entities/`) - Domain models + - **Variable Pool** - Variable storage - **Graph Init Params** - Initialization configuration ## Key Design Patterns ### Command Channel Pattern + External workflow control via Redis or in-memory channels: + ```python # Send stop command to running workflow channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") @@ -48,7 +55,9 @@ channel.send_command(AbortCommand(reason="User requested")) ``` ### Layer System + Extensible middleware for cross-cutting concerns: + ```python engine = GraphEngine(graph) engine.add_layer(DebugLoggingLayer(level="INFO")) @@ -56,14 +65,18 @@ engine.add_layer(ExecutionLimitsLayer(max_nodes=100)) ``` ### Event-Driven Architecture + All node executions emit events for monitoring and integration: + - `NodeRunStartedEvent` - Node execution begins - `NodeRunSucceededEvent` - Node completes successfully - `NodeRunFailedEvent` - Node encounters error - `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle ### Variable Pool + Centralized variable storage with namespace isolation: + ```python # Variables scoped by node_id pool.add(["node1", "output"], value) @@ -75,15 +88,19 @@ result = pool.get(["node1", "output"]) The codebase enforces strict layering via import-linter: 1. 
**Workflow Layers** (top to bottom): + - graph_engine → graph_events → graph → nodes → node_events → entities -2. **Graph Engine Internal Layers**: +1. **Graph Engine Internal Layers**: + - orchestration → command_processing → event_management → graph_traversal → domain -3. **Domain Isolation**: +1. **Domain Isolation**: + - Domain models cannot import from infrastructure layers -4. **Command Channel Independence**: +1. **Command Channel Independence**: + - InMemory and Redis channels must remain independent ## Common Tasks @@ -91,20 +108,21 @@ The codebase enforces strict layering via import-linter: ### Adding a New Node Type 1. Create node class in `nodes//` -2. Inherit from `BaseNode` or appropriate base class -3. Implement `_run()` method -4. Register in `nodes/node_mapping.py` -5. Add tests in `tests/unit_tests/core/workflow/nodes/` +1. Inherit from `BaseNode` or appropriate base class +1. Implement `_run()` method +1. Register in `nodes/node_mapping.py` +1. Add tests in `tests/unit_tests/core/workflow/nodes/` ### Implementing a Custom Layer 1. Create class inheriting from `Layer` base -2. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -3. Add to engine via `engine.add_layer()` +1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` +1. Add to engine via `engine.add_layer()` ### Debugging Workflow Execution Enable debug logging layer: + ```python debug_layer = DebugLoggingLayer( level="DEBUG", From b4ef1de30fcdbea9700e0bc8a2ae2474ea7ccbda Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 15 Sep 2025 03:05:10 +0800 Subject: [PATCH 23/31] feat(graph_engine): add ready_queue state persistence to GraphRuntimeState - Add ReadyQueueState TypedDict for type-safe queue serialization - Add ready_queue attribute to GraphRuntimeState for initializing with pre-existing queue state - Update GraphEngine to load ready_queue from GraphRuntimeState on initialization - Implement proper type hints using ReadyQueueState for better type safety - Add comprehensive tests for ready_queue loading functionality The ready_queue is read-only after initialization and allows resuming workflow execution with a pre-populated queue of nodes ready to execute. 
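For reviewers, a minimal usage sketch of the resume flow, assuming only the names introduced in this patch (ReadyQueueState, InMemoryReadyQueue, the ready_queue argument on GraphRuntimeState); the call sites are illustrative, not part of the diff:

```python
from time import time

from core.workflow.entities import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue, ReadyQueueState

# Snapshot a ready queue (e.g. when pausing a run). Items stay in FIFO order.
ready_queue = InMemoryReadyQueue()
ready_queue.put("node-a")
ready_queue.put("node-b")
snapshot: ReadyQueueState = ready_queue.dumps()

# Later, rebuild the runtime state from that snapshot; GraphEngine loads it
# into its own InMemoryReadyQueue during initialization, so execution resumes
# with the same set of ready nodes.
state = GraphRuntimeState(
    variable_pool=VariablePool(),
    start_at=time(),
    ready_queue=snapshot,
)
assert state.ready_queue["items"] == ["node-a", "node-b"]
```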
--- .../workflow/entities/graph_runtime_state.py | 16 +++- .../workflow/graph_engine/graph_engine.py | 10 +++ .../graph_engine/ready_queue/__init__.py | 4 +- .../graph_engine/ready_queue/in_memory.py | 20 ++--- .../graph_engine/ready_queue/protocol.py | 22 +++++- .../entities/test_graph_runtime_state.py | 28 +++++++ .../graph_engine/test_graph_engine.py | 75 +++++++++++++++++++ 7 files changed, 159 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index c06a62d1e7..c9ec426167 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, PrivateAttr @@ -7,6 +7,9 @@ from core.model_runtime.entities.llm_entities import LLMUsage from .variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.graph_engine.ready_queue import ReadyQueueState + class GraphRuntimeState(BaseModel): # Private attributes to prevent direct modification @@ -16,6 +19,7 @@ class GraphRuntimeState(BaseModel): _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) _outputs: dict[str, Any] = PrivateAttr(default_factory=dict) _node_run_steps: int = PrivateAttr(default=0) + _ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict) def __init__( self, @@ -25,6 +29,7 @@ class GraphRuntimeState(BaseModel): llm_usage: LLMUsage | None = None, outputs: dict[str, Any] | None = None, node_run_steps: int = 0, + ready_queue: "ReadyQueueState | dict[str, object] | None" = None, **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -51,6 +56,10 @@ class GraphRuntimeState(BaseModel): raise ValueError("node_run_steps must be non-negative") self._node_run_steps = node_run_steps + if ready_queue is None: + ready_queue = {} + self._ready_queue = deepcopy(ready_queue) + @property def variable_pool(self) -> VariablePool: """Get the variable pool.""" @@ -133,3 +142,8 @@ class GraphRuntimeState(BaseModel): if tokens < 0: raise ValueError("tokens must be non-negative") self._total_tokens += tokens + + @property + def ready_queue(self) -> "ReadyQueueState | dict[str, object]": + """Get a copy of the ready queue state.""" + return deepcopy(self._ready_queue) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 6e58d19fd6..a7b582d803 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -106,6 +106,16 @@ class GraphEngine: # === Execution Queues === # Queue for nodes ready to execute self._ready_queue = InMemoryReadyQueue() + # Load ready queue state from GraphRuntimeState if not empty + ready_queue_state = self._graph_runtime_state.ready_queue + if ready_queue_state: + # Import ReadyQueueState here to avoid circular imports + from .ready_queue import ReadyQueueState + + # Ensure we have a ReadyQueueState object + if isinstance(ready_queue_state, dict): + ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore + self._ready_queue.loads(ready_queue_state) # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/core/workflow/graph_engine/ready_queue/__init__.py index 9b890880f5..448abda286 100644 --- 
a/api/core/workflow/graph_engine/ready_queue/__init__.py +++ b/api/core/workflow/graph_engine/ready_queue/__init__.py @@ -6,6 +6,6 @@ the queue of nodes ready for execution. """ from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue +from .protocol import ReadyQueue, ReadyQueueState -__all__ = ["InMemoryReadyQueue", "ReadyQueue"] +__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"] diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py index 90df9a0096..c3cfbb00ad 100644 --- a/api/core/workflow/graph_engine/ready_queue/in_memory.py +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -8,6 +8,8 @@ serialization capabilities for state storage. import queue from typing import final +from .protocol import ReadyQueueState + @final class InMemoryReadyQueue: @@ -80,12 +82,12 @@ class InMemoryReadyQueue: """ return self._queue.qsize() - def dumps(self) -> dict[str, object]: + def dumps(self) -> ReadyQueueState: """ Serialize the queue state for storage. Returns: - A dictionary containing the serialized queue state + A ReadyQueueState dictionary containing the serialized queue state """ # Extract all items from the queue without removing them items: list[str] = [] @@ -104,14 +106,14 @@ class InMemoryReadyQueue: for item in temp_items: self._queue.put(item) - return { - "type": "InMemoryReadyQueue", - "version": "1.0", - "items": items, - "maxsize": self._queue.maxsize, - } + return ReadyQueueState( + type="InMemoryReadyQueue", + version="1.0", + items=items, + maxsize=self._queue.maxsize, + ) - def loads(self, data: dict[str, object]) -> None: + def loads(self, data: ReadyQueueState) -> None: """ Restore the queue state from serialized data. diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/core/workflow/graph_engine/ready_queue/protocol.py index 0e457bcf05..d0f66d2955 100644 --- a/api/core/workflow/graph_engine/ready_queue/protocol.py +++ b/api/core/workflow/graph_engine/ready_queue/protocol.py @@ -5,7 +5,21 @@ This protocol defines the interface for managing the queue of nodes ready for execution, supporting both in-memory and persistent storage scenarios. """ -from typing import Protocol +from typing import Protocol, TypedDict + + +class ReadyQueueState(TypedDict): + """ + TypedDict for serialized ready queue state. + + This defines the structure of the dictionary returned by dumps() + and expected by loads() for ready queue serialization. + """ + + type: str # Queue implementation type (e.g., "InMemoryReadyQueue") + version: str # Serialization format version + items: list[str] # List of node IDs in the queue + maxsize: int # Maximum queue size (0 for unlimited) class ReadyQueue(Protocol): @@ -68,17 +82,17 @@ class ReadyQueue(Protocol): """ ... - def dumps(self) -> dict[str, object]: + def dumps(self) -> ReadyQueueState: """ Serialize the queue state for storage. Returns: - A dictionary containing the serialized queue state + A ReadyQueueState dictionary containing the serialized queue state that can be persisted and later restored """ ... - def loads(self, data: dict[str, object]) -> None: + def loads(self, data: ReadyQueueState) -> None: """ Restore the queue state from serialized data. 
diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 4d8483ce0d..067b8d8186 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,6 +4,7 @@ import pytest from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.ready_queue import ReadyQueueState class TestGraphRuntimeState: @@ -109,3 +110,30 @@ class TestGraphRuntimeState: # Original should remain unchanged assert state.get_output("nested")["level1"]["level2"]["value"] == "test" + + def test_ready_queue_property(self): + variable_pool = VariablePool() + + # Test default empty ready_queue + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + assert state.ready_queue == {} + + # Test initialization with ready_queue data as ReadyQueueState + queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0) + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data) + assert state.ready_queue == queue_data + + # Test with different ready_queue data at initialization + another_queue_data = ReadyQueueState( + type="InMemoryReadyQueue", + version="1.0", + items=["node3", "node4", "node5"], + maxsize=0, + ) + another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data) + assert another_state.ready_queue == another_queue_data + + # Test immutability - modifying retrieved queue doesn't affect internal state + retrieved_queue = state.ready_queue + retrieved_queue["items"].append("node6") + assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4aa33bde26..f03c19ab1c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -744,3 +744,78 @@ def test_event_sequence_validation_with_table_tests(): else: assert result.event_sequence_match is True assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" + + +def test_ready_queue_state_loading(): + """ + Test that the ready_queue state is properly loaded from GraphRuntimeState + during GraphEngine initialization. 
+ """ + # Use the TableTestRunner to create a proper workflow instance + runner = TableTestRunner() + + # Create a simple workflow + test_case = WorkflowTestCase( + fixture_path="simple_passthrough_workflow", + inputs={"query": "test"}, + expected_outputs={"query": "test"}, + description="Test ready_queue loading", + ) + + # Load the workflow fixture + workflow_runner = runner.workflow_runner + fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow") + + # Create graph and runtime state with pre-populated ready_queue + ready_queue_data = { + "type": "InMemoryReadyQueue", + "version": "1.0", + "items": ["node1", "node2", "node3"], + "maxsize": 0, + } + + # We need to create the graph first, then create a new GraphRuntimeState with ready_queue + graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test") + + # Create a new GraphRuntimeState with the ready_queue data + from core.workflow.entities import GraphRuntimeState + from core.workflow.graph_engine.ready_queue import ReadyQueueState + + # Convert ready_queue_data to ReadyQueueState + ready_queue_state = ReadyQueueState(**ready_queue_data) + + graph_runtime_state = GraphRuntimeState( + variable_pool=original_runtime_state.variable_pool, + start_at=original_runtime_state.start_at, + ready_queue=ready_queue_state, + ) + + # Update all nodes to use the new GraphRuntimeState + for node in graph.nodes.values(): + node.graph_runtime_state = graph_runtime_state + + # Create GraphEngine + command_channel = InMemoryChannel() + engine = GraphEngine( + tenant_id="test-tenant", + app_id="test-app", + workflow_id="test-workflow", + user_id="test-user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + graph=graph, + graph_config={}, + graph_runtime_state=graph_runtime_state, + command_channel=command_channel, + ) + + # Verify that the ready_queue was loaded from GraphRuntimeState + assert engine._ready_queue.qsize() == 3 + + # Verify the initial state matches what was provided + initial_queue_state = engine.graph_runtime_state.ready_queue + assert initial_queue_state["type"] == "InMemoryReadyQueue" + assert initial_queue_state["version"] == "1.0" + assert len(initial_queue_state["items"]) == 3 + assert initial_queue_state["items"] == ["node1", "node2", "node3"] From a099a35e51377af75d4612f472d7b93b2d26e0b2 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 07:56:51 +0000 Subject: [PATCH 24/31] [autofix.ci] apply automated fixes --- api/core/ops/ops_trace_manager.py | 20 +++---- .../index_processor/index_processor_base.py | 2 +- api/core/tools/tool_engine.py | 20 +++---- api/core/tools/tool_manager.py | 4 +- api/core/tools/workflow_as_tool/tool.py | 6 +- api/models/model.py | 58 +++++++++---------- api/models/tools.py | 4 +- api/services/workflow_service.py | 10 ++-- .../graph_engine/test_context_preservation.py | 4 +- .../workflow/graph_engine/test_mock_config.py | 6 +- .../workflow/graph_engine/test_mock_nodes.py | 2 +- .../graph_engine/test_table_runner.py | 32 +++++----- 12 files changed, 84 insertions(+), 84 deletions(-) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 4805faa5ab..41491d0ed2 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -220,7 +220,7 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: Optional[TraceAppConfig] = ( + trace_config_data: 
TraceAppConfig | None = ( db.session.query(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -244,7 +244,7 @@ class OpsTraceManager: @classmethod def get_ops_trace_instance( cls, - app_id: Optional[Union[UUID, str]] = None, + app_id: Union[UUID, str] | None = None, ): """ Get ops trace through model config @@ -257,7 +257,7 @@ class OpsTraceManager: if app_id is None: return None - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if app is None: return None @@ -331,7 +331,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.query(App).where(App.id == app_id).first() if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -349,7 +349,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: Optional[App] = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.query(App).where(App.id == app_id).first() if not app: raise ValueError("App not found") if not app.tracing: @@ -407,11 +407,11 @@ class TraceTask: def __init__( self, trace_type: Any, - message_id: Optional[str] = None, + message_id: str | None = None, workflow_execution: Optional["WorkflowExecution"] = None, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - timer: Optional[Any] = None, + conversation_id: str | None = None, + user_id: str | None = None, + timer: Any | None = None, **kwargs, ): self.trace_type = trace_type @@ -825,7 +825,7 @@ class TraceTask: return generate_name_trace_info -trace_manager_timer: Optional[threading.Timer] = None +trace_manager_timer: threading.Timer | None = None trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index a19b756b91..8da56c99e3 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -33,7 +33,7 @@ class BaseIndexProcessor(ABC): raise NotImplementedError @abstractmethod - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 5acac20739..6c5777b2c2 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -51,10 +51,10 @@ class ToolEngine: message: Message, invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + trace_manager: TraceQueueManager | None = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> tuple[str, list[str], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. 
@@ -152,9 +152,9 @@ class ToolEngine: user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ Workflow invokes the tool with the given arguments. @@ -194,9 +194,9 @@ class ToolEngine: tool: Tool, tool_parameters: dict, user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: """ Invoke the tool with the given arguments. diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 18b972928a..d9ae352af5 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -157,7 +157,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -446,7 +446,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Tool: """ get tool runtime from plugin diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 73163e0e69..bb365bda19 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -61,9 +61,9 @@ class WorkflowTool(Tool): self, user_id: str, tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke the tool diff --git a/api/models/model.py b/api/models/model.py index d5f9543a52..7789e6ad9d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -74,9 +74,9 @@ class App(Base): name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji icon = mapped_column(String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(String(255)) + icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) @@ -88,7 +88,7 @@ class App(Base): is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) tracing = mapped_column(sa.Text, nullable=True) - max_active_requests: Mapped[Optional[int]] + max_active_requests: Mapped[int | None] created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, 
server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -132,7 +132,7 @@ class App(Base): return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -290,7 +290,7 @@ class App(Base): return tags or [] @property - def author_name(self) -> Optional[str]: + def author_name(self) -> str | None: if self.created_by: account = db.session.query(Account).where(Account.id == self.created_by).first() if account: @@ -333,7 +333,7 @@ class AppModelConfig(Base): file_upload = mapped_column(sa.Text) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -545,7 +545,7 @@ class RecommendedApp(Base): updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @@ -569,12 +569,12 @@ class InstalledApp(Base): created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: app = db.session.query(App).where(App.id == self.app_id).first() return app @property - def tenant(self) -> Optional[Tenant]: + def tenant(self) -> Tenant | None: tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return tenant @@ -710,7 +710,7 @@ class Conversation(Base): @property def model_config(self): model_config = {} - app_model_config: Optional[AppModelConfig] = None + app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: @@ -844,7 +844,7 @@ class Conversation(Base): ) @property - def app(self) -> Optional[App]: + def app(self) -> App | None: with Session(db.engine, expire_on_commit=False) as session: return session.query(App).where(App.id == self.app_id).first() @@ -858,7 +858,7 @@ class Conversation(Base): return None @property - def from_account_name(self) -> Optional[str]: + def from_account_name(self) -> str | None: if self.from_account_id: account = db.session.query(Account).where(Account.id == self.from_account_id).first() if account: @@ -933,14 +933,14 @@ class Message(Base): status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) error = mapped_column(sa.Text) message_metadata = mapped_column(sa.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) from_source: Mapped[str] = mapped_column(String(255), nullable=False) - from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) - from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) + from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) + from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + 
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) @property def inputs(self) -> dict[str, Any]: @@ -1337,9 +1337,9 @@ class MessageFile(Base): message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + url: Mapped[str | None] = mapped_column(sa.Text, nullable=True) + belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True) + upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -1356,8 +1356,8 @@ class MessageAnnotation(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) - message_id: Mapped[Optional[str]] = mapped_column(StringUUID) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) + message_id: Mapped[str | None] = mapped_column(StringUUID) question = mapped_column(sa.Text, nullable=True) content = mapped_column(sa.Text, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) @@ -1729,18 +1729,18 @@ class MessageAgentThought(Base): # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design tool_process_data = mapped_column(sa.Text, nullable=True) message = mapped_column(sa.Text, nullable=True) - message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) message_unit_price = mapped_column(sa.Numeric, nullable=True) message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) message_files = mapped_column(sa.Text, nullable=True) answer = mapped_column(sa.Text, nullable=True) - answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_token: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) answer_unit_price = mapped_column(sa.Numeric, nullable=True) answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) - tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) total_price = mapped_column(sa.Numeric, nullable=True) currency = mapped_column(String, nullable=True) - latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -1838,11 +1838,11 @@ class DatasetRetrieverResource(Base): document_name = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column(sa.Text, 
nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + score: Mapped[float | None] = mapped_column(sa.Float, nullable=True) content = mapped_column(sa.Text, nullable=False) - hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) - segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) index_node_hash = mapped_column(sa.Text, nullable=True) retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/tools.py b/api/models/tools.py index 545c29357d..6ab0b8b0c8 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -502,13 +502,13 @@ class ToolFile(TypeBase): # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID) # conversation id - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) # file key file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True, default=None) + original_url: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None) # name name: Mapped[str] = mapped_column(default="") # size diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 655effad3e..1ff0b00d65 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -84,7 +84,7 @@ class WorkflowService: ) return db.session.execute(stmt).scalar_one() - def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]: + def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: """ Get draft workflow """ @@ -104,7 +104,7 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: """ fetch published workflow by workflow_id """ @@ -126,7 +126,7 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Optional[Workflow]: + def get_published_workflow(self, app_model: App) -> Workflow | None: """ Get published workflow """ @@ -191,7 +191,7 @@ class WorkflowService: app_model: App, graph: dict, features: dict, - unique_hash: Optional[str], + unique_hash: str | None, account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], @@ -883,7 +883,7 @@ class WorkflowService: def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Optional[Workflow]: + ) -> Workflow | None: """ Update workflow attributes diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py b/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py index b4bc67c595..38b91f08f1 100644 --- 
a/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py @@ -59,7 +59,7 @@ class TestContextPreservation: context = contextvars.copy_context() # Variable to store value from worker - worker_value: Optional[str] = None + worker_value: str | None = None def worker_task() -> None: nonlocal worker_value @@ -120,7 +120,7 @@ class TestContextPreservation: test_node = MagicMock(spec=Node) # Variable to capture context inside node execution - captured_value: Optional[str] = None + captured_value: str | None = None context_available_in_node = False def mock_run() -> list[GraphNodeEventBase]: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index 2bd60cc67c..ce92d81e21 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -18,9 +18,9 @@ class NodeMockConfig: node_id: str outputs: dict[str, Any] = field(default_factory=dict) - error: Optional[str] = None + error: str | None = None delay: float = 0.0 # Simulated execution delay in seconds - custom_handler: Optional[Callable[..., dict[str, Any]]] = None + custom_handler: Callable[..., dict[str, Any]] | None = None @dataclass @@ -51,7 +51,7 @@ class MockConfig: default_template_transform_response: str = "This is mocked template transform output" default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"}) - def get_node_config(self, node_id: str) -> Optional[NodeMockConfig]: + def get_node_config(self, node_id: str) -> NodeMockConfig | None: """Get configuration for a specific node.""" return self.node_configs.get(node_id) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8229409ffd..e944c6f83e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -64,7 +64,7 @@ class MockNodeMixin: return default_outputs - def _should_simulate_error(self) -> Optional[str]: + def _should_simulate_error(self) -> str | None: """Check if this node should simulate an error.""" if not self.mock_config: return None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index e2b646b7a6..06b9a9d917 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -60,14 +60,14 @@ class WorkflowTestCase: query: str = "" description: str = "" timeout: float = 30.0 - mock_config: Optional[MockConfig] = None + mock_config: MockConfig | None = None use_auto_mock: bool = False - expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None + expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None tags: list[str] = field(default_factory=list) skip: bool = False skip_reason: str = "" retry_count: int = 0 - custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None + custom_validator: Callable[[dict[str, Any]], bool] | None = None @dataclass @@ -76,14 +76,14 @@ class WorkflowTestResult: test_case: WorkflowTestCase success: bool - error: Optional[Exception] = None - actual_outputs: Optional[dict[str, 
Any]] = None + error: Exception | None = None + actual_outputs: dict[str, Any] | None = None execution_time: float = 0.0 - event_sequence_match: Optional[bool] = None - event_mismatch_details: Optional[str] = None + event_sequence_match: bool | None = None + event_mismatch_details: str | None = None events: list[GraphEngineEvent] = field(default_factory=list) retry_attempts: int = 0 - validation_details: Optional[str] = None + validation_details: str | None = None @dataclass @@ -116,7 +116,7 @@ class TestSuiteResult: class WorkflowRunner: """Core workflow execution engine for tests.""" - def __init__(self, fixtures_dir: Optional[Path] = None): + def __init__(self, fixtures_dir: Path | None = None): """Initialize the workflow runner.""" if fixtures_dir is None: # Use the new central fixtures location @@ -147,9 +147,9 @@ class WorkflowRunner: self, fixture_data: dict[str, Any], query: str = "", - inputs: Optional[dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, use_mock_factory: bool = False, - mock_config: Optional[MockConfig] = None, + mock_config: MockConfig | None = None, ) -> tuple[Graph, GraphRuntimeState]: """Create a Graph instance from fixture data.""" workflow_config = fixture_data.get("workflow", {}) @@ -240,7 +240,7 @@ class TableTestRunner: def __init__( self, - fixtures_dir: Optional[Path] = None, + fixtures_dir: Path | None = None, max_workers: int = 4, enable_logging: bool = False, log_level: str = "INFO", @@ -467,8 +467,8 @@ class TableTestRunner: self, expected_outputs: dict[str, Any], actual_outputs: dict[str, Any], - custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None, - ) -> tuple[bool, Optional[str]]: + custom_validator: Callable[[dict[str, Any]], bool] | None = None, + ) -> tuple[bool, str | None]: """ Validate actual outputs against expected outputs. @@ -517,7 +517,7 @@ class TableTestRunner: def _validate_event_sequence( self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent] - ) -> tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: """ Validate that actual events match the expected event sequence. 
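# --- Editor's note: illustrative sketch, not part of the patch above. It shows one way the
# WorkflowTestCase fields from these hunks (query, description, custom_validator) might be
# combined; the fixture name and the validator body are assumptions made for illustration.
def _has_answer(outputs: dict[str, Any]) -> bool:
    # Accept any run whose outputs carry a non-empty "answer" value.
    return bool(outputs.get("answer"))

case = WorkflowTestCase(
    fixture_path="simple_passthrough_workflow",  # assumed fixture name
    query="hello",
    description="validated with custom_validator instead of exact output matching",
    custom_validator=_has_answer,
)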
@@ -549,7 +549,7 @@ class TableTestRunner: self, test_cases: list[WorkflowTestCase], parallel: bool = False, - tags_filter: Optional[list[str]] = None, + tags_filter: list[str] | None = None, fail_fast: bool = False, ) -> TestSuiteResult: """ From 754d790c89f7d9e34c41a684dde8c2a0c565cbb1 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Mon, 15 Sep 2025 07:58:44 +0000 Subject: [PATCH 25/31] [autofix.ci] apply automated fixes (attempt 2/3) --- api/core/tools/tool_engine.py | 2 +- api/core/tools/workflow_as_tool/tool.py | 2 +- api/models/tools.py | 2 +- api/services/workflow_service.py | 2 +- .../core/workflow/graph_engine/test_context_preservation.py | 1 - .../unit_tests/core/workflow/graph_engine/test_mock_config.py | 2 +- .../unit_tests/core/workflow/graph_engine/test_table_runner.py | 2 +- 7 files changed, 6 insertions(+), 7 deletions(-) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 6c5777b2c2..9fb6062770 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterable from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from yarl import URL diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index bb365bda19..5adf04611d 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional +from typing import Any from sqlalchemy import select diff --git a/api/models/tools.py b/api/models/tools.py index 6ab0b8b0c8..7211d7aa3a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping from datetime import datetime -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse import sqlalchemy as sa diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 1ff0b00d65..447483dfaa 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py b/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py index 38b91f08f1..c2175f048c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_context_preservation.py @@ -9,7 +9,6 @@ import contextvars import queue import threading import time -from typing import Optional from unittest.mock import MagicMock from flask import Flask, g diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index ce92d81e21..b02f90588b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -7,7 +7,7 @@ the behavior of mock nodes during testing. 
from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any from core.workflow.enums import NodeType diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 06b9a9d917..01a8521550 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -17,7 +17,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, Optional +from typing import Any from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.utils.yaml_utils import _load_yaml_file From d5342927d09a9a2378fa20670e77e20e83eeaf16 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 01:01:38 +0800 Subject: [PATCH 26/31] chore: change _outputs type to dict[str, object] --- api/core/workflow/entities/graph_runtime_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index c9ec426167..aefdde5fc7 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -17,7 +17,7 @@ class GraphRuntimeState(BaseModel): _start_at: float = PrivateAttr() _total_tokens: int = PrivateAttr(default=0) _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) - _outputs: dict[str, Any] = PrivateAttr(default_factory=dict) + _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object]) _node_run_steps: int = PrivateAttr(default=0) _ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict) From da87fce751f77b12532135842a3d1063bb70d293 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 03:00:15 +0800 Subject: [PATCH 27/31] feat(graph_engine): dump and load ready queue --- api/.importlinter | 1 - .../workflow/entities/graph_runtime_state.py | 29 +++---- .../workflow/graph_engine/graph_engine.py | 20 ++--- .../graph_engine/ready_queue/__init__.py | 3 +- .../graph_engine/ready_queue/factory.py | 35 +++++++++ .../graph_engine/ready_queue/in_memory.py | 34 ++++----- .../graph_engine/ready_queue/protocol.py | 30 ++++---- .../graph_engine/response_coordinator/path.py | 2 +- .../entities/test_graph_runtime_state.py | 42 ----------- .../graph_engine/test_graph_engine.py | 75 ------------------- 10 files changed, 89 insertions(+), 182 deletions(-) create mode 100644 api/core/workflow/graph_engine/ready_queue/factory.py diff --git a/api/.importlinter b/api/.importlinter index c5c4126330..98fe5f50bb 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -31,7 +31,6 @@ ignore_imports = core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine core.workflow.nodes.loop.loop_node -> core.workflow.graph core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels - core.workflow.entities.graph_runtime_state -> core.workflow.graph_engine.ready_queue [importlinter:contract:rsc] name = RSC diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index aefdde5fc7..2b29a36d82 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -1,5 +1,4 @@ from copy import 
deepcopy -from typing import TYPE_CHECKING, Any from pydantic import BaseModel, PrivateAttr @@ -7,9 +6,6 @@ from core.model_runtime.entities.llm_entities import LLMUsage from .variable_pool import VariablePool -if TYPE_CHECKING: - from core.workflow.graph_engine.ready_queue import ReadyQueueState - class GraphRuntimeState(BaseModel): # Private attributes to prevent direct modification @@ -19,17 +15,18 @@ class GraphRuntimeState(BaseModel): _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object]) _node_run_steps: int = PrivateAttr(default=0) - _ready_queue: "ReadyQueueState | dict[str, object]" = PrivateAttr(default_factory=dict) + _ready_queue_json: str = PrivateAttr() def __init__( self, + *, variable_pool: VariablePool, start_at: float, total_tokens: int = 0, llm_usage: LLMUsage | None = None, - outputs: dict[str, Any] | None = None, + outputs: dict[str, object] | None = None, node_run_steps: int = 0, - ready_queue: "ReadyQueueState | dict[str, object] | None" = None, + ready_queue_json: str = "", **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -56,9 +53,7 @@ class GraphRuntimeState(BaseModel): raise ValueError("node_run_steps must be non-negative") self._node_run_steps = node_run_steps - if ready_queue is None: - ready_queue = {} - self._ready_queue = deepcopy(ready_queue) + self._ready_queue_json = ready_queue_json @property def variable_pool(self) -> VariablePool: @@ -99,24 +94,24 @@ class GraphRuntimeState(BaseModel): self._llm_usage = value.model_copy() @property - def outputs(self) -> dict[str, Any]: + def outputs(self) -> dict[str, object]: """Get a copy of the outputs dictionary.""" return deepcopy(self._outputs) @outputs.setter - def outputs(self, value: dict[str, Any]) -> None: + def outputs(self, value: dict[str, object]) -> None: """Set the outputs dictionary.""" self._outputs = deepcopy(value) - def set_output(self, key: str, value: Any) -> None: + def set_output(self, key: str, value: object) -> None: """Set a single output value.""" self._outputs[key] = deepcopy(value) - def get_output(self, key: str, default: Any = None) -> Any: + def get_output(self, key: str, default: object = None) -> object: """Get a single output value.""" return deepcopy(self._outputs.get(key, default)) - def update_outputs(self, updates: dict[str, Any]) -> None: + def update_outputs(self, updates: dict[str, object]) -> None: """Update multiple output values.""" for key, value in updates.items(): self._outputs[key] = deepcopy(value) @@ -144,6 +139,6 @@ class GraphRuntimeState(BaseModel): self._total_tokens += tokens @property - def ready_queue(self) -> "ReadyQueueState | dict[str, object]": + def ready_queue_json(self) -> str: """Get a copy of the ready queue state.""" - return deepcopy(self._ready_queue) + return self._ready_queue_json diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index a7b582d803..dc85619421 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -18,6 +18,7 @@ from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper +from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue from core.workflow.graph_events import ( GraphEngineEvent, GraphNodeEventBase, @@ 
-38,7 +39,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import InMemoryReadyQueue +from .ready_queue import ReadyQueueState, create_ready_queue_from_state from .response_coordinator import ResponseStreamCoordinator from .worker_management import WorkerPool @@ -104,18 +105,13 @@ class GraphEngine: self._scale_down_idle_time = scale_down_idle_time # === Execution Queues === - # Queue for nodes ready to execute - self._ready_queue = InMemoryReadyQueue() - # Load ready queue state from GraphRuntimeState if not empty - ready_queue_state = self._graph_runtime_state.ready_queue - if ready_queue_state: - # Import ReadyQueueState here to avoid circular imports - from .ready_queue import ReadyQueueState + # Create ready queue from saved state or initialize new one + if self._graph_runtime_state.ready_queue_json == "": + self._ready_queue = InMemoryReadyQueue() + else: + ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json) + self._ready_queue = create_ready_queue_from_state(ready_queue_state) - # Ensure we have a ReadyQueueState object - if isinstance(ready_queue_state, dict): - ready_queue_state = ReadyQueueState(**ready_queue_state) # type: ignore - self._ready_queue.loads(ready_queue_state) # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/core/workflow/graph_engine/ready_queue/__init__.py index 448abda286..acba0e961c 100644 --- a/api/core/workflow/graph_engine/ready_queue/__init__.py +++ b/api/core/workflow/graph_engine/ready_queue/__init__.py @@ -5,7 +5,8 @@ This package contains the protocol and implementations for managing the queue of nodes ready for execution. """ +from .factory import create_ready_queue_from_state from .in_memory import InMemoryReadyQueue from .protocol import ReadyQueue, ReadyQueueState -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState"] +__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/core/workflow/graph_engine/ready_queue/factory.py new file mode 100644 index 0000000000..1144e1de69 --- /dev/null +++ b/api/core/workflow/graph_engine/ready_queue/factory.py @@ -0,0 +1,35 @@ +""" +Factory for creating ReadyQueue instances from serialized state. +""" + +from typing import TYPE_CHECKING + +from .in_memory import InMemoryReadyQueue +from .protocol import ReadyQueueState + +if TYPE_CHECKING: + from .protocol import ReadyQueue + + +def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue": + """ + Create a ReadyQueue instance from a serialized state. 
+ + Args: + state: The serialized queue state as a ReadyQueueState instance + + Returns: + A ReadyQueue instance initialized with the given state + + Raises: + ValueError: If the queue type is unknown or version is unsupported + """ + if state.type == "InMemoryReadyQueue": + if state.version != "1.0": + raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") + queue = InMemoryReadyQueue() + # Always pass as JSON string to loads() + queue.loads(state.model_dump_json()) + return queue + else: + raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py index c3cfbb00ad..e01ecdc160 100644 --- a/api/core/workflow/graph_engine/ready_queue/in_memory.py +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -82,12 +82,12 @@ class InMemoryReadyQueue: """ return self._queue.qsize() - def dumps(self) -> ReadyQueueState: + def dumps(self) -> str: """ - Serialize the queue state for storage. + Serialize the queue state to a JSON string for storage. Returns: - A ReadyQueueState dictionary containing the serialized queue state + A JSON string containing the serialized queue state """ # Extract all items from the queue without removing them items: list[str] = [] @@ -106,25 +106,27 @@ class InMemoryReadyQueue: for item in temp_items: self._queue.put(item) - return ReadyQueueState( + state = ReadyQueueState( type="InMemoryReadyQueue", version="1.0", items=items, - maxsize=self._queue.maxsize, ) + return state.model_dump_json() - def loads(self, data: ReadyQueueState) -> None: + def loads(self, data: str) -> None: """ - Restore the queue state from serialized data. + Restore the queue state from a JSON string. Args: - data: The serialized queue state to restore + data: The JSON string containing the serialized queue state to restore """ - if data.get("type") != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {data.get('type')}") + state = ReadyQueueState.model_validate_json(data) - if data.get("version") != "1.0": - raise ValueError(f"Unsupported version: {data.get('version')}") + if state.type != "InMemoryReadyQueue": + raise ValueError(f"Invalid serialized data type: {state.type}") + + if state.version != "1.0": + raise ValueError(f"Unsupported version: {state.version}") # Clear the current queue while not self._queue.empty(): @@ -134,11 +136,5 @@ class InMemoryReadyQueue: break # Restore items - items = data.get("items", []) - if not isinstance(items, list): - raise ValueError("Invalid items data: expected list") - - for item in items: - if not isinstance(item, str): - raise ValueError(f"Invalid item type: expected str, got {type(item).__name__}") + for item in state.items: self._queue.put(item) diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/core/workflow/graph_engine/ready_queue/protocol.py index d0f66d2955..97d3ea6dd2 100644 --- a/api/core/workflow/graph_engine/ready_queue/protocol.py +++ b/api/core/workflow/graph_engine/ready_queue/protocol.py @@ -5,21 +5,23 @@ This protocol defines the interface for managing the queue of nodes ready for execution, supporting both in-memory and persistent storage scenarios.
""" -from typing import Protocol, TypedDict +from collections.abc import Sequence +from typing import Protocol + +from pydantic import BaseModel, Field -class ReadyQueueState(TypedDict): +class ReadyQueueState(BaseModel): """ - TypedDict for serialized ready queue state. + Pydantic model for serialized ready queue state. - This defines the structure of the dictionary returned by dumps() + This defines the structure of the data returned by dumps() and expected by loads() for ready queue serialization. """ - type: str # Queue implementation type (e.g., "InMemoryReadyQueue") - version: str # Serialization format version - items: list[str] # List of node IDs in the queue - maxsize: int # Maximum queue size (0 for unlimited) + type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") + version: str = Field(description="Serialization format version") + items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") class ReadyQueue(Protocol): @@ -82,21 +84,21 @@ class ReadyQueue(Protocol): """ ... - def dumps(self) -> ReadyQueueState: + def dumps(self) -> str: """ - Serialize the queue state for storage. + Serialize the queue state to a JSON string for storage. Returns: - A ReadyQueueState dictionary containing the serialized queue state + A JSON string containing the serialized queue state that can be persisted and later restored """ ... - def loads(self, data: ReadyQueueState) -> None: + def loads(self, data: str) -> None: """ - Restore the queue state from serialized data. + Restore the queue state from a JSON string. Args: - data: The serialized queue state to restore + data: The JSON string containing the serialized queue state to restore """ ... diff --git a/api/core/workflow/graph_engine/response_coordinator/path.py b/api/core/workflow/graph_engine/response_coordinator/path.py index d83dd5e77b..50f2f4eb21 100644 --- a/api/core/workflow/graph_engine/response_coordinator/path.py +++ b/api/core/workflow/graph_engine/response_coordinator/path.py @@ -19,7 +19,7 @@ class Path: Note: This is an internal class not exposed in the public API. 
""" - edges: list[EdgeID] = field(default_factory=list) + edges: list[EdgeID] = field(default_factory=list[EdgeID]) def contains_edge(self, edge_id: EdgeID) -> bool: """Check if this path contains the given edge.""" diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 067b8d8186..2614424dc7 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,7 +4,6 @@ import pytest from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.ready_queue import ReadyQueueState class TestGraphRuntimeState: @@ -96,44 +95,3 @@ class TestGraphRuntimeState: # Test add_tokens validation with pytest.raises(ValueError): state.add_tokens(-1) - - def test_deep_copy_for_nested_objects(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test deep copy for nested dict - nested_data = {"level1": {"level2": {"value": "test"}}} - state.set_output("nested", nested_data) - - retrieved = state.get_output("nested") - retrieved["level1"]["level2"]["value"] = "modified" - - # Original should remain unchanged - assert state.get_output("nested")["level1"]["level2"]["value"] == "test" - - def test_ready_queue_property(self): - variable_pool = VariablePool() - - # Test default empty ready_queue - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - assert state.ready_queue == {} - - # Test initialization with ready_queue data as ReadyQueueState - queue_data = ReadyQueueState(type="InMemoryReadyQueue", version="1.0", items=["node1", "node2"], maxsize=0) - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=queue_data) - assert state.ready_queue == queue_data - - # Test with different ready_queue data at initialization - another_queue_data = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=["node3", "node4", "node5"], - maxsize=0, - ) - another_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time(), ready_queue=another_queue_data) - assert another_state.ready_queue == another_queue_data - - # Test immutability - modifying retrieved queue doesn't affect internal state - retrieved_queue = state.ready_queue - retrieved_queue["items"].append("node6") - assert len(state.ready_queue["items"]) == 2 # Should still be 2, not 3 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index f03c19ab1c..4aa33bde26 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -744,78 +744,3 @@ def test_event_sequence_validation_with_table_tests(): else: assert result.event_sequence_match is True assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or result.error}" - - -def test_ready_queue_state_loading(): - """ - Test that the ready_queue state is properly loaded from GraphRuntimeState - during GraphEngine initialization. 
- """ - # Use the TableTestRunner to create a proper workflow instance - runner = TableTestRunner() - - # Create a simple workflow - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test"}, - expected_outputs={"query": "test"}, - description="Test ready_queue loading", - ) - - # Load the workflow fixture - workflow_runner = runner.workflow_runner - fixture_data = workflow_runner.load_fixture("simple_passthrough_workflow") - - # Create graph and runtime state with pre-populated ready_queue - ready_queue_data = { - "type": "InMemoryReadyQueue", - "version": "1.0", - "items": ["node1", "node2", "node3"], - "maxsize": 0, - } - - # We need to create the graph first, then create a new GraphRuntimeState with ready_queue - graph, original_runtime_state = workflow_runner.create_graph_from_fixture(fixture_data, query="test") - - # Create a new GraphRuntimeState with the ready_queue data - from core.workflow.entities import GraphRuntimeState - from core.workflow.graph_engine.ready_queue import ReadyQueueState - - # Convert ready_queue_data to ReadyQueueState - ready_queue_state = ReadyQueueState(**ready_queue_data) - - graph_runtime_state = GraphRuntimeState( - variable_pool=original_runtime_state.variable_pool, - start_at=original_runtime_state.start_at, - ready_queue=ready_queue_state, - ) - - # Update all nodes to use the new GraphRuntimeState - for node in graph.nodes.values(): - node.graph_runtime_state = graph_runtime_state - - # Create GraphEngine - command_channel = InMemoryChannel() - engine = GraphEngine( - tenant_id="test-tenant", - app_id="test-app", - workflow_id="test-workflow", - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - graph=graph, - graph_config={}, - graph_runtime_state=graph_runtime_state, - command_channel=command_channel, - ) - - # Verify that the ready_queue was loaded from GraphRuntimeState - assert engine._ready_queue.qsize() == 3 - - # Verify the initial state matches what was provided - initial_queue_state = engine.graph_runtime_state.ready_queue - assert initial_queue_state["type"] == "InMemoryReadyQueue" - assert initial_queue_state["version"] == "1.0" - assert len(initial_queue_state["items"]) == 3 - assert initial_queue_state["items"] == ["node1", "node2", "node3"] From 5f263147f95de641967b2784566cf59aa8c286a0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 12:51:11 +0800 Subject: [PATCH 28/31] fix: make mypy happy --- api/core/workflow/graph_engine/graph_engine.py | 3 ++- api/core/workflow/graph_engine/ready_queue/in_memory.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dc85619421..434ad4fc5e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -39,7 +39,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import ReadyQueueState, create_ready_queue_from_state +from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state from .response_coordinator import ResponseStreamCoordinator from .worker_management import WorkerPool @@ -106,6 +106,7 @@ class GraphEngine: # === Execution Queues === # Create ready queue from saved state or initialize new one + 
self._ready_queue: ReadyQueue if self._graph_runtime_state.ready_queue_json == "": self._ready_queue = InMemoryReadyQueue() else: diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/core/workflow/graph_engine/ready_queue/in_memory.py index e01ecdc160..f2c265ece0 100644 --- a/api/core/workflow/graph_engine/ready_queue/in_memory.py +++ b/api/core/workflow/graph_engine/ready_queue/in_memory.py @@ -8,11 +8,11 @@ serialization capabilities for state storage. import queue from typing import final -from .protocol import ReadyQueueState +from .protocol import ReadyQueue, ReadyQueueState @final -class InMemoryReadyQueue: +class InMemoryReadyQueue(ReadyQueue): """ In-memory ready queue implementation with serialization support. From b5684f1992055f43e9c94eadcdb51d7c2ad3a439 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 14:11:42 +0800 Subject: [PATCH 29/31] refactor(graph_engine): remove unused parameters from Engine --- .../command_channels/redis_channel.py | 6 +++- .../workflow/graph_engine/domain/__init__.py | 2 -- .../graph_engine/domain/execution_context.py | 31 ------------------- .../graph_engine/entities/commands.py | 4 +-- .../workflow/graph_engine/graph_engine.py | 26 ++-------------- .../nodes/iteration/iteration_node.py | 7 ----- api/core/workflow/nodes/loop/loop_node.py | 7 ----- api/core/workflow/workflow_entry.py | 7 ----- .../graph_engine/test_command_system.py | 9 ------ ...ditional_streaming_vs_template_workflow.py | 22 ------------- .../graph_engine/test_graph_engine.py | 23 -------------- .../workflow/graph_engine/test_mock_nodes.py | 14 --------- .../test_parallel_streaming_workflow.py | 7 ----- .../graph_engine/test_table_runner.py | 12 ------- .../graph_engine/test_tool_in_chatflow.py | 12 ------- 15 files changed, 9 insertions(+), 180 deletions(-) delete mode 100644 api/core/workflow/graph_engine/domain/execution_context.py diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index ad0aa9402c..056e17bf5d 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -97,8 +97,12 @@ class RedisChannel: Returns: Deserialized command or None if invalid """ + command_type_value = data.get("command_type") + if not isinstance(command_type_value, str): + return None + try: - command_type = CommandType(data.get("command_type")) + command_type = CommandType(command_type_value) if command_type == CommandType.ABORT: return AbortCommand(**data) diff --git a/api/core/workflow/graph_engine/domain/__init__.py b/api/core/workflow/graph_engine/domain/__init__.py index cf6d3e6aa3..9e9afe4c21 100644 --- a/api/core/workflow/graph_engine/domain/__init__.py +++ b/api/core/workflow/graph_engine/domain/__init__.py @@ -5,12 +5,10 @@ This package contains the core domain entities, value objects, and aggregates that represent the business concepts of workflow graph execution. 
""" -from .execution_context import ExecutionContext from .graph_execution import GraphExecution from .node_execution import NodeExecution __all__ = [ - "ExecutionContext", "GraphExecution", "NodeExecution", ] diff --git a/api/core/workflow/graph_engine/domain/execution_context.py b/api/core/workflow/graph_engine/domain/execution_context.py deleted file mode 100644 index 9bcff0fea7..0000000000 --- a/api/core/workflow/graph_engine/domain/execution_context.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -ExecutionContext value object containing immutable execution parameters. -""" - -from dataclasses import dataclass - -from core.app.entities.app_invoke_entities import InvokeFrom -from models.enums import UserFrom - - -@dataclass(frozen=True) -class ExecutionContext: - """ - Immutable value object containing the context for a graph execution. - - This encapsulates all the contextual information needed to execute a workflow, - keeping it separate from the mutable execution state. - """ - - tenant_id: str - app_id: str - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - call_depth: int - - def __post_init__(self) -> None: - """Validate execution context parameters.""" - if self.call_depth < 0: - raise ValueError("Call depth must be non-negative") diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 7e25fc0866..123ef3d449 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,13 +5,13 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. """ -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field -class CommandType(str, Enum): +class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" ABORT = "abort" diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 434ad4fc5e..b0daf694ce 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -8,12 +8,11 @@ Domain-Driven Design principles for improved maintainability and testability. 
import contextvars import logging import queue -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import final from flask import Flask, current_app -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphRuntimeState from core.workflow.enums import NodeExecutionType from core.workflow.graph import Graph @@ -27,10 +26,9 @@ from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from models.enums import UserFrom from .command_processing import AbortCommandHandler, CommandProcessor -from .domain import ExecutionContext, GraphExecution +from .domain import GraphExecution from .entities.commands import AbortCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager @@ -57,15 +55,8 @@ class GraphEngine: def __init__( self, - tenant_id: str, - app_id: str, workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, graph: Graph, - graph_config: Mapping[str, object], graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, min_workers: int | None = None, @@ -75,25 +66,12 @@ class GraphEngine: ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" - # === Domain Models === - # Execution context encapsulates workflow execution metadata - self._execution_context = ExecutionContext( - tenant_id=tenant_id, - app_id=app_id, - workflow_id=workflow_id, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - call_depth=call_depth, - ) - # Graph execution tracks the overall execution state self._graph_execution = GraphExecution(workflow_id=workflow_id) # === Core Dependencies === # Graph structure and configuration self._graph = graph - self._graph_config = graph_config self._graph_runtime_state = graph_runtime_state self._command_channel = command_channel diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 524cd2c40b..5340a5b6ce 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -551,15 +551,8 @@ class IterationNode(Node): # Create a new GraphEngine for this iteration graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=iteration_graph, - graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2217bc205e..25b0c4f4fe 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -443,15 +443,8 @@ class LoopNode(Node): # Create a new GraphEngine for this iteration graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=loop_graph, - graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 901c830b17..f26f3a8008 
100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -73,15 +73,8 @@ class WorkflowEntry: self.command_channel = command_channel self.graph_engine = GraphEngine( - tenant_id=tenant_id, - app_id=app_id, workflow_id=workflow_id, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - call_depth=call_depth, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=command_channel, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 58073ba5c3..9fec855a93 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,14 +3,12 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import GraphRuntimeState, VariablePool from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_engine.entities.commands import AbortCommand from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent -from models.enums import UserFrom def test_abort_command(): @@ -42,15 +40,8 @@ def test_abort_command(): # Create GraphEngine with same shared runtime state engine = GraphEngine( - tenant_id="test", - app_id="test", workflow_id="test_workflow", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, graph=mock_graph, - graph_config={}, graph_runtime_state=shared_runtime_state, # Use shared instance command_channel=command_channel, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index 2b2e4fe022..70a772fc4c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,7 +6,6 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.enums import NodeType from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel @@ -16,7 +15,6 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from models.enums import UserFrom from .test_table_runner import TableTestRunner @@ -40,20 +38,10 @@ def test_streaming_output_with_blocking_equals_one(): use_mock_factory=True, ) - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - # Create and run the engine engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) @@ -145,20 +133,10 @@ def test_streaming_output_with_blocking_not_equals_one(): use_mock_factory=True, ) - workflow_config = fixture_data.get("workflow", {}) - graph_config = 
workflow_config.get("graph", {}) - # Create and run the engine engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4aa33bde26..6a723999de 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,11 +10,9 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent -from models.enums import UserFrom # Import the test framework from the new module from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase @@ -460,15 +458,8 @@ def test_layer_system_basic(): # Create engine with layer engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, graph=graph, - graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) @@ -523,15 +514,8 @@ def test_layer_chaining(): # Create engine engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, graph=graph, - graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) @@ -577,15 +561,8 @@ def test_layer_error_handling(): # Create engine with faulty layer engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, graph=graph, - graph_config=fixture_data.get("workflow", {}).get("graph", {}), graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index e944c6f83e..e5ae32bbff 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -615,15 +615,8 @@ class MockIterationNode(MockNodeMixin, IterationNode): # Create a new GraphEngine for this iteration graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=iteration_graph, - graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) @@ -683,15 +676,8 @@ class MockLoopNode(MockNodeMixin, LoopNode): # Create a new GraphEngine for this 
iteration graph_engine = GraphEngine( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, - user_id=self.user_id, - user_from=self.user_from, - invoke_from=self.invoke_from, - call_depth=self.workflow_call_depth, graph=loop_graph, - graph_config=self.graph_config, graph_runtime_state=graph_runtime_state_copy, command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 04f0aa7f2e..d1f1f53b78 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -118,15 +118,8 @@ def test_parallel_streaming_workflow(): # Create the graph engine engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 01a8521550..0f3a142b1a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,7 +19,6 @@ from functools import lru_cache from pathlib import Path from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.variables import ( ArrayNumberVariable, @@ -42,7 +41,6 @@ from core.workflow.graph_events import ( ) from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory @@ -373,20 +371,10 @@ class TableTestRunner: mock_config=test_case.mock_config, ) - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - # Create and run the engine with configured worker settings engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, # Use DEBUGGER to avoid conversation_id requirement - call_depth=0, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), min_workers=self.graph_engine_min_workers, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index e227518a8e..34682ff8f9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,11 +1,9 @@ -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel from core.workflow.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) -from models.enums import UserFrom from .test_table_runner import TableTestRunner @@ -23,20 +21,10 @@ def 
test_tool_in_chatflow(): use_mock_factory=True, ) - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - # Create and run the engine engine = GraphEngine( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, graph=graph, - graph_config=graph_config, graph_runtime_state=graph_runtime_state, command_channel=InMemoryChannel(), ) From 02d15ebd5a05ba21abb706a80c7ba90ee91e9b45 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 16 Sep 2025 19:38:10 +0800 Subject: [PATCH 30/31] feat(graph_engine): support dumps and loads in GraphExecution --- .../workflow/entities/graph_runtime_state.py | 13 ++ .../graph_engine/domain/graph_execution.py | 142 ++++++++++++++- .../workflow/graph_engine/graph_engine.py | 2 + .../test_graph_execution_serialization.py | 165 ++++++++++++++++++ 4 files changed, 319 insertions(+), 3 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index 2b29a36d82..c8fb1de20e 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -16,6 +16,7 @@ class GraphRuntimeState(BaseModel): _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object]) _node_run_steps: int = PrivateAttr(default=0) _ready_queue_json: str = PrivateAttr() + _graph_execution_json: str = PrivateAttr() def __init__( self, @@ -27,6 +28,7 @@ class GraphRuntimeState(BaseModel): outputs: dict[str, object] | None = None, node_run_steps: int = 0, ready_queue_json: str = "", + graph_execution_json: str = "", **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" @@ -54,6 +56,7 @@ class GraphRuntimeState(BaseModel): self._node_run_steps = node_run_steps self._ready_queue_json = ready_queue_json + self._graph_execution_json = graph_execution_json @property def variable_pool(self) -> VariablePool: @@ -142,3 +145,13 @@ class GraphRuntimeState(BaseModel): def ready_queue_json(self) -> str: """Get a copy of the ready queue state.""" return self._ready_queue_json + + @property + def graph_execution_json(self) -> str: + """Get a copy of the serialized graph execution state.""" + return self._graph_execution_json + + @graph_execution_json.setter + def graph_execution_json(self, value: str) -> None: + """Set the serialized graph execution state.""" + self._graph_execution_json = value diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/core/workflow/graph_engine/domain/graph_execution.py index c375b08fe0..5951af1087 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/core/workflow/graph_engine/domain/graph_execution.py @@ -1,12 +1,94 @@ -""" -GraphExecution aggregate root managing the overall graph execution state. 
-""" +"""GraphExecution aggregate root managing the overall graph execution state.""" + +from __future__ import annotations from dataclasses import dataclass, field +from importlib import import_module +from typing import Literal + +from pydantic import BaseModel, Field + +from core.workflow.enums import NodeState from .node_execution import NodeExecution +class GraphExecutionErrorState(BaseModel): + """Serializable representation of an execution error.""" + + module: str = Field(description="Module containing the exception class") + qualname: str = Field(description="Qualified name of the exception class") + message: str | None = Field(default=None, description="Exception message string") + + +class NodeExecutionState(BaseModel): + """Serializable representation of a node execution entity.""" + + node_id: str + state: NodeState = Field(default=NodeState.UNKNOWN) + retry_count: int = Field(default=0) + execution_id: str | None = Field(default=None) + error: str | None = Field(default=None) + + +class GraphExecutionState(BaseModel): + """Pydantic model describing serialized GraphExecution state.""" + + type: Literal["GraphExecution"] = Field(default="GraphExecution") + version: str = Field(default="1.0") + workflow_id: str + started: bool = Field(default=False) + completed: bool = Field(default=False) + aborted: bool = Field(default=False) + error: GraphExecutionErrorState | None = Field(default=None) + node_executions: list[NodeExecutionState] = Field(default_factory=list) + + +def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: + """Convert an exception into its serializable representation.""" + + if error is None: + return None + + return GraphExecutionErrorState( + module=error.__class__.__module__, + qualname=error.__class__.__qualname__, + message=str(error), + ) + + +def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: + """Locate an exception class from its module and qualified name.""" + + module = import_module(module_name) + attr: object = module + for part in qualname.split("."): + attr = getattr(attr, part) + + if isinstance(attr, type) and issubclass(attr, Exception): + return attr + + raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") + + +def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: + """Reconstruct an exception instance from serialized data.""" + + if state is None: + return None + + try: + exception_class = _resolve_exception_class(state.module, state.qualname) + if state.message is None: + return exception_class() + return exception_class(state.message) + except Exception: + # Fallback to RuntimeError when reconstruction fails + if state.message is None: + return RuntimeError(state.qualname) + return RuntimeError(state.message) + + @dataclass class GraphExecution: """ @@ -69,3 +151,57 @@ class GraphExecution: if not self.error: return None return str(self.error) + + def dumps(self) -> str: + """Serialize the aggregate state into a JSON string.""" + + node_states = [ + NodeExecutionState( + node_id=node_id, + state=node_execution.state, + retry_count=node_execution.retry_count, + execution_id=node_execution.execution_id, + error=node_execution.error, + ) + for node_id, node_execution in sorted(self.node_executions.items()) + ] + + state = GraphExecutionState( + workflow_id=self.workflow_id, + started=self.started, + completed=self.completed, + aborted=self.aborted, + error=_serialize_error(self.error), + node_executions=node_states, + ) + + return 
state.model_dump_json() + + def loads(self, data: str) -> None: + """Restore aggregate state from a serialized JSON string.""" + + state = GraphExecutionState.model_validate_json(data) + + if state.type != "GraphExecution": + raise ValueError(f"Invalid serialized data type: {state.type}") + + if state.version != "1.0": + raise ValueError(f"Unsupported serialized version: {state.version}") + + if self.workflow_id != state.workflow_id: + raise ValueError("Serialized workflow_id does not match aggregate identity") + + self.started = state.started + self.completed = state.completed + self.aborted = state.aborted + self.error = _deserialize_error(state.error) + self.node_executions = { + item.node_id: NodeExecution( + node_id=item.node_id, + state=item.state, + retry_count=item.retry_count, + execution_id=item.execution_id, + error=item.error, + ) + for item in state.node_executions + } diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b0daf694ce..1a136d4365 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -68,6 +68,8 @@ class GraphEngine: # Graph execution tracks the overall execution state self._graph_execution = GraphExecution(workflow_id=workflow_id) + if graph_runtime_state.graph_execution_json != "": + self._graph_execution.loads(graph_runtime_state.graph_execution_json) # === Core Dependencies === # Graph structure and configuration diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py new file mode 100644 index 0000000000..2388e4d57b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -0,0 +1,165 @@ +"""Unit tests for GraphExecution serialization helpers.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +from core.workflow.entities import GraphRuntimeState +from core.workflow.enums import NodeExecutionType, NodeState +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.domain import GraphExecution + + +class CustomGraphExecutionError(Exception): + """Custom exception used to verify error serialization.""" + + +def test_graph_execution_serialization_round_trip() -> None: + """GraphExecution serialization restores full aggregate state.""" + # Arrange + execution = GraphExecution(workflow_id="wf-1") + execution.start() + node_a = execution.get_or_create_node_execution("node-a") + node_a.mark_started(execution_id="exec-1") + node_a.increment_retry() + node_a.mark_failed("boom") + node_b = execution.get_or_create_node_execution("node-b") + node_b.mark_skipped() + execution.fail(CustomGraphExecutionError("serialization failure")) + + # Act + serialized = execution.dumps() + payload = json.loads(serialized) + restored = GraphExecution(workflow_id="wf-1") + restored.loads(serialized) + + # Assert + assert payload["type"] == "GraphExecution" + assert payload["version"] == "1.0" + assert restored.workflow_id == "wf-1" + assert restored.started is True + assert restored.completed is True + assert restored.aborted is False + assert isinstance(restored.error, CustomGraphExecutionError) + assert str(restored.error) == "serialization failure" + assert set(restored.node_executions) == {"node-a", "node-b"} + restored_node_a = restored.node_executions["node-a"] + assert restored_node_a.state is NodeState.TAKEN + 
assert restored_node_a.retry_count == 1 + assert restored_node_a.execution_id == "exec-1" + assert restored_node_a.error == "boom" + restored_node_b = restored.node_executions["node-b"] + assert restored_node_b.state is NodeState.SKIPPED + assert restored_node_b.retry_count == 0 + assert restored_node_b.execution_id is None + assert restored_node_b.error is None + + +def test_graph_execution_loads_replaces_existing_state() -> None: + """loads replaces existing runtime data with serialized snapshot.""" + # Arrange + source = GraphExecution(workflow_id="wf-2") + source.start() + source_node = source.get_or_create_node_execution("node-source") + source_node.mark_taken() + serialized = source.dumps() + + target = GraphExecution(workflow_id="wf-2") + target.start() + target.abort("pre-existing abort") + temp_node = target.get_or_create_node_execution("node-temp") + temp_node.increment_retry() + temp_node.mark_failed("temp error") + + # Act + target.loads(serialized) + + # Assert + assert target.aborted is False + assert target.error is None + assert target.started is True + assert target.completed is False + assert set(target.node_executions) == {"node-source"} + restored_node = target.node_executions["node-source"] + assert restored_node.state is NodeState.TAKEN + assert restored_node.retry_count == 0 + assert restored_node.execution_id is None + assert restored_node.error is None + + +def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None: + """GraphEngine restores GraphExecution state from runtime snapshot on init.""" + + # Arrange serialized execution state + execution = GraphExecution(workflow_id="wf-init") + execution.start() + node_state = execution.get_or_create_node_execution("serialized-node") + node_state.mark_taken() + execution.complete() + serialized = execution.dumps() + + runtime_state = GraphRuntimeState( + variable_pool=MagicMock(), + start_at=0.0, + graph_execution_json=serialized, + ) + + class DummyNode: + def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: + self.graph_runtime_state = graph_runtime_state + self.execution_type = NodeExecutionType.EXECUTABLE + self.id = "dummy-node" + self.state = NodeState.UNKNOWN + self.title = "dummy" + + class DummyGraph: + def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: + self.nodes = {"dummy-node": DummyNode(graph_runtime_state)} + self.edges: dict[str, object] = {} + self.root_node = self.nodes["dummy-node"] + + def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised + return [] + + def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised + return [] + + dummy_graph = DummyGraph(runtime_state) + + def _stub(*_args, **_kwargs): + return MagicMock() + + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub) + monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub) + command_processor = MagicMock() + command_processor.register_handler = MagicMock() + monkeypatch.setattr( + "core.workflow.graph_engine.graph_engine.CommandProcessor", + lambda 
*_args, **_kwargs: command_processor,
+    )
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub)
+    monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub)
+
+    # Act
+    engine = GraphEngine(
+        workflow_id="wf-init",
+        graph=dummy_graph,  # type: ignore[arg-type]
+        graph_runtime_state=runtime_state,
+        command_channel=MagicMock(),
+    )
+
+    # Assert
+    assert engine._graph_execution.started is True
+    assert engine._graph_execution.completed is True
+    assert set(engine._graph_execution.node_executions) == {"serialized-node"}
+    restored_node = engine._graph_execution.node_executions["serialized-node"]
+    assert restored_node.state is NodeState.TAKEN
+    assert restored_node.retry_count == 0

From 73a77563509d2ab279c50ce7e031596e47c4d4f6 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Wed, 17 Sep 2025 12:45:51 +0800
Subject: [PATCH 31/31] feat(graph_engine): support dumps and loads in ResponseStreamCoordinator

- Add response_coordinator_json to GraphRuntimeState
- Serialize response sessions, stream buffers, positions, and closed streams in ResponseStreamCoordinator
- Restore coordinator state in GraphEngine when a serialized snapshot is present
- Extend serialization unit tests to cover the coordinator round trip

---
 .../workflow/entities/graph_runtime_state.py  |  11 +-
 .../workflow/graph_engine/graph_engine.py     |   2 +
 .../response_coordinator/coordinator.py       | 138 +++++++++++++++-
 .../test_graph_execution_serialization.py     | 151 +++++++++++-------
 4 files changed, 236 insertions(+), 66 deletions(-)

diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py
index c8fb1de20e..6362f291ea 100644
--- a/api/core/workflow/entities/graph_runtime_state.py
+++ b/api/core/workflow/entities/graph_runtime_state.py
@@ -17,6 +17,7 @@ class GraphRuntimeState(BaseModel):
     _node_run_steps: int = PrivateAttr(default=0)
     _ready_queue_json: str = PrivateAttr()
     _graph_execution_json: str = PrivateAttr()
+    _response_coordinator_json: str = PrivateAttr()
 
     def __init__(
         self,
@@ -29,6 +30,7 @@
         node_run_steps: int = 0,
         ready_queue_json: str = "",
         graph_execution_json: str = "",
+        response_coordinator_json: str = "",
         **kwargs: object,
     ):
         """Initialize the GraphRuntimeState with validation."""
@@ -57,6 +59,7 @@
 
         self._ready_queue_json = ready_queue_json
         self._graph_execution_json = graph_execution_json
+        self._response_coordinator_json = response_coordinator_json
 
     @property
     def variable_pool(self) -> VariablePool:
@@ -151,7 +154,7 @@
         """Get a copy of the serialized graph execution state."""
         return self._graph_execution_json
 
-    @graph_execution_json.setter
-    def graph_execution_json(self, value: str) -> None:
-        """Set the serialized graph execution state."""
-        self._graph_execution_json = value
+    @property
+    def response_coordinator_json(self) -> str:
+        """Get a copy of the serialized response coordinator state."""
+        return self._response_coordinator_json
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 1a136d4365..164ae41cca 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -105,6 +105,8 @@ class GraphEngine:
         self._response_coordinator = ResponseStreamCoordinator(
             variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
         )
+        if graph_runtime_state.response_coordinator_json != "":
+            self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
 
         # === Event Management ===
         # Event manager handles both collection and
emission of events diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index b5224cbc22..985992f3f1 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -9,9 +9,11 @@ import logging from collections import deque from collections.abc import Sequence from threading import RLock -from typing import TypeAlias, final +from typing import Literal, TypeAlias, final from uuid import uuid4 +from pydantic import BaseModel, Field + from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import NodeExecutionType, NodeState from core.workflow.graph import Graph @@ -28,6 +30,43 @@ NodeID: TypeAlias = str EdgeID: TypeAlias = str +class ResponseSessionState(BaseModel): + """Serializable representation of a response session.""" + + node_id: str + index: int = Field(default=0, ge=0) + + +class StreamBufferState(BaseModel): + """Serializable representation of buffered stream chunks.""" + + selector: tuple[str, ...] + events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) + + +class StreamPositionState(BaseModel): + """Serializable representation for stream read positions.""" + + selector: tuple[str, ...] + position: int = Field(default=0, ge=0) + + +class ResponseStreamCoordinatorState(BaseModel): + """Serialized snapshot of ResponseStreamCoordinator.""" + + type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") + version: str = Field(default="1.0") + response_nodes: Sequence[str] = Field(default_factory=list) + active_session: ResponseSessionState | None = None + waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) + pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) + node_execution_ids: dict[str, str] = Field(default_factory=dict) + paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) + stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) + stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) + closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) + + @final class ResponseStreamCoordinator: """ @@ -69,6 +108,8 @@ class ResponseStreamCoordinator: def register(self, response_node_id: NodeID) -> None: with self._lock: + if response_node_id in self._response_nodes: + return self._response_nodes.add(response_node_id) # Build and save paths map for this response node @@ -558,3 +599,98 @@ class ResponseStreamCoordinator: """ key = tuple(selector) return key in self._closed_streams + + def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: + """Convert an in-memory session into its serializable form.""" + + if session is None: + return None + return ResponseSessionState(node_id=session.node_id, index=session.index) + + def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: + """Rebuild a response session from serialized data.""" + + node = self._graph.nodes.get(session_state.node_id) + if node is None: + raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") + + session = ResponseSession.from_node(node) + session.index = session_state.index + return session + + def dumps(self) -> str: + """Serialize coordinator state to JSON.""" + + with self._lock: + state = ResponseStreamCoordinatorState( + 
response_nodes=sorted(self._response_nodes), + active_session=self._serialize_session(self._active_session), + waiting_sessions=[ + session_state + for session in list(self._waiting_sessions) + if (session_state := self._serialize_session(session)) is not None + ], + pending_sessions=[ + session_state + for _, session in sorted(self._response_sessions.items()) + if (session_state := self._serialize_session(session)) is not None + ], + node_execution_ids=dict(sorted(self._node_execution_ids.items())), + paths_map={ + node_id: [path.edges.copy() for path in paths] + for node_id, paths in sorted(self._paths_maps.items()) + }, + stream_buffers=[ + StreamBufferState( + selector=selector, + events=[event.model_copy(deep=True) for event in events], + ) + for selector, events in sorted(self._stream_buffers.items()) + ], + stream_positions=[ + StreamPositionState(selector=selector, position=position) + for selector, position in sorted(self._stream_positions.items()) + ], + closed_streams=sorted(self._closed_streams), + ) + return state.model_dump_json() + + def loads(self, data: str) -> None: + """Restore coordinator state from JSON.""" + + state = ResponseStreamCoordinatorState.model_validate_json(data) + + if state.type != "ResponseStreamCoordinator": + raise ValueError(f"Invalid serialized data type: {state.type}") + + if state.version != "1.0": + raise ValueError(f"Unsupported serialized version: {state.version}") + + with self._lock: + self._response_nodes = set(state.response_nodes) + self._paths_maps = { + node_id: [Path(edges=list(path_edges)) for path_edges in paths] + for node_id, paths in state.paths_map.items() + } + self._node_execution_ids = dict(state.node_execution_ids) + + self._stream_buffers = { + tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] + for buffer in state.stream_buffers + } + self._stream_positions = { + tuple(position.selector): position.position for position in state.stream_positions + } + for selector in self._stream_buffers: + self._stream_positions.setdefault(selector, 0) + + self._closed_streams = {tuple(selector) for selector in state.closed_streams} + + self._waiting_sessions = deque( + self._session_from_state(session_state) for session_state in state.waiting_sessions + ) + self._response_sessions = { + session_state.node_id: self._session_from_state(session_state) + for session_state in state.pending_sessions + } + self._active_session = self._session_from_state(state.active_session) if state.active_session else None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 2388e4d57b..6385b0b91f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -3,12 +3,16 @@ from __future__ import annotations import json +from collections import deque from unittest.mock import MagicMock -from core.workflow.entities import GraphRuntimeState -from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph_engine import GraphEngine +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.graph_engine.domain import GraphExecution +from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator +from core.workflow.graph_engine.response_coordinator.path import Path +from 
core.workflow.graph_engine.response_coordinator.session import ResponseSession +from core.workflow.graph_events import NodeRunStreamChunkEvent +from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): @@ -88,78 +92,103 @@ def test_graph_execution_loads_replaces_existing_state() -> None: assert restored_node.error is None -def test_graph_engine_initializes_from_serialized_execution(monkeypatch) -> None: - """GraphEngine restores GraphExecution state from runtime snapshot on init.""" +def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: + """ResponseStreamCoordinator serialization restores coordinator internals.""" - # Arrange serialized execution state - execution = GraphExecution(workflow_id="wf-init") - execution.start() - node_state = execution.get_or_create_node_execution("serialized-node") - node_state.mark_taken() - execution.complete() - serialized = execution.dumps() - - runtime_state = GraphRuntimeState( - variable_pool=MagicMock(), - start_at=0.0, - graph_execution_json=serialized, - ) + template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) + template_secondary = Template(segments=[TextSegment(text="secondary")]) class DummyNode: - def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: - self.graph_runtime_state = graph_runtime_state - self.execution_type = NodeExecutionType.EXECUTABLE - self.id = "dummy-node" + def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: + self.id = node_id + self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM + self.execution_type = execution_type self.state = NodeState.UNKNOWN - self.title = "dummy" + self.title = node_id + self.template = template + + def blocks_variable_output(self, *_args) -> bool: + return False + + response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE) + response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) + response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) + source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) class DummyGraph: - def __init__(self, graph_runtime_state: GraphRuntimeState) -> None: - self.nodes = {"dummy-node": DummyNode(graph_runtime_state)} + def __init__(self) -> None: + self.nodes = { + response_node1.id: response_node1, + response_node2.id: response_node2, + response_node3.id: response_node3, + source_node.id: source_node, + } self.edges: dict[str, object] = {} - self.root_node = self.nodes["dummy-node"] + self.root_node = response_node1 - def get_incoming_edges(self, node_id: str): # pragma: no cover - not exercised + def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised return [] - def get_outgoing_edges(self, node_id: str): # pragma: no cover - not exercised + def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised return [] - dummy_graph = DummyGraph(runtime_state) + graph = DummyGraph() - def _stub(*_args, **_kwargs): - return MagicMock() + def fake_from_node(cls, node: DummyNode) -> ResponseSession: + return ResponseSession(node_id=node.id, template=node.template) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.GraphStateManager", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ResponseStreamCoordinator", _stub) - 
monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventManager", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ErrorHandler", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.SkipPropagator", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EdgeProcessor", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.EventHandler", _stub) - command_processor = MagicMock() - command_processor.register_handler = MagicMock() - monkeypatch.setattr( - "core.workflow.graph_engine.graph_engine.CommandProcessor", - lambda *_args, **_kwargs: command_processor, + monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) + + coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] + coordinator._response_nodes = {"response-1", "response-2", "response-3"} + coordinator._paths_maps = { + "response-1": [Path(edges=["edge-1"])], + "response-2": [Path(edges=[])], + "response-3": [Path(edges=["edge-2", "edge-3"])], + } + + active_session = ResponseSession(node_id="response-1", template=response_node1.template) + active_session.index = 1 + coordinator._active_session = active_session + waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) + coordinator._waiting_sessions = deque([waiting_session]) + pending_session = ResponseSession(node_id="response-3", template=response_node3.template) + pending_session.index = 2 + coordinator._response_sessions = {"response-3": pending_session} + + coordinator._node_execution_ids = {"response-1": "exec-1"} + event = NodeRunStreamChunkEvent( + id="exec-1", + node_id="response-1", + node_type=NodeType.ANSWER, + selector=["node-source", "text"], + chunk="chunk-1", + is_final=False, ) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.AbortCommandHandler", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.WorkerPool", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.ExecutionCoordinator", _stub) - monkeypatch.setattr("core.workflow.graph_engine.graph_engine.Dispatcher", _stub) + coordinator._stream_buffers = {("node-source", "text"): [event]} + coordinator._stream_positions = {("node-source", "text"): 1} + coordinator._closed_streams = {("node-source", "text")} - # Act - engine = GraphEngine( - workflow_id="wf-init", - graph=dummy_graph, # type: ignore[arg-type] - graph_runtime_state=runtime_state, - command_channel=MagicMock(), - ) + serialized = coordinator.dumps() - # Assert - assert engine._graph_execution.started is True - assert engine._graph_execution.completed is True - assert set(engine._graph_execution.node_executions) == {"serialized-node"} - restored_node = engine._graph_execution.node_executions["serialized-node"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 + restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] + monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) + restored.loads(serialized) + + assert restored._response_nodes == {"response-1", "response-2", "response-3"} + assert restored._paths_maps["response-1"][0].edges == ["edge-1"] + assert restored._active_session is not None + assert restored._active_session.node_id == "response-1" + assert restored._active_session.index == 1 + waiting_restored = list(restored._waiting_sessions) + assert len(waiting_restored) == 1 + assert 
waiting_restored[0].node_id == "response-2" + assert waiting_restored[0].index == 0 + assert set(restored._response_sessions) == {"response-3"} + assert restored._response_sessions["response-3"].index == 2 + assert restored._node_execution_ids == {"response-1": "exec-1"} + assert ("node-source", "text") in restored._stream_buffers + restored_event = restored._stream_buffers[("node-source", "text")][0] + assert restored_event.chunk == "chunk-1" + assert restored._stream_positions[("node-source", "text")] == 1 + assert ("node-source", "text") in restored._closed_streams
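Taken together, these last two patches make a paused run fully serializable: GraphExecution and ResponseStreamCoordinator each dump to JSON, and GraphEngine rehydrates both from GraphRuntimeState at construction time. The sketch below shows one way a caller might wire the round trip. It is a minimal sketch, not part of the patches: the helper names snapshot_engine_state and resume_runtime_state are invented, it reaches into the engine's private _graph_execution and _response_coordinator attributes because the patches expose no public accessor for producing the snapshots, and the caller is assumed to persist the returned strings in whatever store the application actually uses.

from core.workflow.entities import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import GraphEngine


def snapshot_engine_state(engine: GraphEngine) -> dict[str, str]:
    # Capture the JSON snapshots introduced by patches 30 and 31. Reading the
    # private attributes is an assumption; the patches only show GraphEngine
    # consuming these snapshots in __init__, not producing them.
    return {
        "graph_execution_json": engine._graph_execution.dumps(),
        "response_coordinator_json": engine._response_coordinator.dumps(),
    }


def resume_runtime_state(
    snapshot: dict[str, str],
    variable_pool: VariablePool,
    start_at: float,
) -> GraphRuntimeState:
    # Build a runtime state whose non-empty serialized fields GraphEngine.__init__
    # passes to GraphExecution.loads() and ResponseStreamCoordinator.loads().
    return GraphRuntimeState(
        variable_pool=variable_pool,
        start_at=start_at,
        graph_execution_json=snapshot.get("graph_execution_json", ""),
        response_coordinator_json=snapshot.get("response_coordinator_json", ""),
    )

A complete pause/resume flow would presumably persist ready_queue_json the same way, since GraphRuntimeState already carries it, but that wiring is outside these patches.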