diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 208736b990..0de5258930 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -98,7 +98,7 @@ class AdvancedChatAppRunner(AppRunner): # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( + workflow_engine_manager.run( workflow=workflow, user_id=application_generate_entity.user_id, user_from=UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 4cb027fa0a..6b22d01340 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -67,7 +67,7 @@ class WorkflowAppRunner: # RUN WORKFLOW workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( + workflow_engine_manager.run( workflow=workflow, user_id=application_generate_entity.user_id, user_from=UserFrom.ACCOUNT diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 97b20643ef..cb8a35b800 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -67,23 +67,24 @@ class NodeRunFailedEvent(BaseNodeEvent): ########################################### -# Parallel Events +# Parallel Branch Events ########################################### -class BaseParallelEvent(GraphEngineEvent): +class BaseParallelBranchEvent(GraphEngineEvent): parallel_id: str = Field(..., description="parallel id") + parallel_start_node_id: str = Field(..., description="parallel start node id") -class ParallelRunStartedEvent(BaseParallelEvent): +class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): pass -class ParallelRunSucceededEvent(BaseParallelEvent): +class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): pass -class ParallelRunFailedEvent(BaseParallelEvent): +class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): reason: str = Field(..., description="failed reason") @@ -113,4 +114,4 @@ class IterationRunFailedEvent(BaseIterationEvent): reason: str = Field(..., description="failed reason") -InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent +InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 73723faa2c..902da57772 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -23,6 +23,9 @@ from core.workflow.graph_engine.entities.event import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams @@ -237,14 +240,12 @@ class GraphEngine: # new thread for edge in edge_mappings: - run_thread = threading.Thread(target=self._run_parallel_node, kwargs={ + threading.Thread(target=self._run_parallel_node, kwargs={ 'flask_app': current_app._get_current_object(), 'parallel_id': parallel_id, 'parallel_start_node_id': edge.target_node_id, 'q': q - }) - - run_thread.start() + }).start() succeeded_count = 0 while True: @@ -253,16 +254,15 @@ class GraphEngine: if event is None: break - if isinstance(event, GraphRunSucceededEvent): + yield event + if isinstance(event, ParallelBranchRunSucceededEvent): succeeded_count += 1 if succeeded_count == len(edge_mappings): break continue - elif isinstance(event, GraphRunFailedEvent): + elif isinstance(event, ParallelBranchRunFailedEvent): raise GraphRunFailedError(event.reason) - else: - yield event except queue.Empty: continue @@ -286,6 +286,11 @@ class GraphEngine: """ with flask_app.app_context(): try: + q.put(ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id + )) + # run node generator = self._run( start_node_id=parallel_start_node_id, @@ -296,12 +301,23 @@ class GraphEngine: q.put(item) # trigger graph run success event - q.put(GraphRunSucceededEvent()) + q.put(ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id + )) except GraphRunFailedError as e: - q.put(GraphRunFailedEvent(reason=e.error)) + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + reason=e.error + )) except Exception as e: logger.exception("Unknown Error when generating in parallel") - q.put(GraphRunFailedEvent(reason=str(e))) + q.put(ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + reason=str(e) + )) finally: db.session.remove() diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 58a785db3e..0107de43d6 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config @@ -11,7 +11,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_engine.entities.event import GraphRunFailedEvent +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.base_node import BaseNode @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class WorkflowEntry: - def run_workflow( + def run( self, *, workflow: Workflow, @@ -37,7 +37,7 @@ class WorkflowEntry: system_inputs: Mapping[SystemVariable, Any], callbacks: Sequence[WorkflowCallback], call_depth: int = 0 - ) -> None: + ) -> Generator[GraphEngineEvent, None, None]: """ :param workflow: Workflow instance :param user_id: user id @@ -110,6 +110,7 @@ class WorkflowEntry: graph_runtime_state=graph_engine.graph_runtime_state, event=event ) + yield event except GenerateTaskStoppedException: pass except Exception as e: @@ -125,10 +126,10 @@ class WorkflowEntry: ) return - def single_step_run_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: + def single_step_run(self, workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: """ Single step run workflow node :param workflow: Workflow instance diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index cf3f429b02..5bbbd2041d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -213,7 +213,7 @@ class WorkflowService: start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + node_instance, node_run_result = workflow_engine_manager.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 61b6f6bbee..c10b7f8a56 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -248,13 +248,13 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): ) ) - print("") + # print("") with patch.object(LLMNode, '_run', new=llm_generator): items = [] generator = graph_engine.run() for item in generator: - print(type(item), item) + # print(type(item), item) items.append(item) if isinstance(item, NodeRunSucceededEvent): assert item.route_node_state.status == RouteNodeState.Status.SUCCESS @@ -267,7 +267,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): ]: assert item.parallel_id is not None - assert len(items) == 17 + assert len(items) == 21 assert isinstance(items[0], GraphRunStartedEvent) assert isinstance(items[1], NodeRunStartedEvent) assert items[1].route_node_state.node_id == 'start' @@ -402,7 +402,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): ]: assert item.parallel_id is not None - assert len(items) == 19 + assert len(items) == 23 assert isinstance(items[0], GraphRunStartedEvent) assert isinstance(items[1], NodeRunStartedEvent) assert items[1].route_node_state.node_id == 'start'