diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 7d3d389cfc..b9e43e02a1 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -104,9 +104,14 @@ class GraphEngine: ) for item in generator: - yield item - if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.') + try: + yield item + if isinstance(item, NodeRunFailedEvent): + yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.') + return + except Exception as e: + logger.exception(f"Graph run failed: {str(e)}") + yield GraphRunFailedEvent(reason=str(e)) return # trigger graph run success event @@ -115,6 +120,7 @@ class GraphEngine: yield GraphRunFailedEvent(reason=e.error) return except Exception as e: + logger.exception("Unknown Error when graph running") yield GraphRunFailedEvent(reason=str(e)) raise e @@ -182,7 +188,22 @@ class GraphEngine: break if len(edge_mappings) == 1: - next_node_id = edge_mappings[0].target_node_id + edge = edge_mappings[0] + if edge.run_condition: + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state, + target_node_id=edge.target_node_id, + ) + + if not result: + break + + next_node_id = edge.target_node_id else: if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index dafc01cee9..179c11bb97 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -6,6 +6,7 @@ from core.file.file_obj import FileVar from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, + NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) @@ -31,7 +32,12 @@ class AnswerStreamProcessor: generator: Generator[GraphEngineEvent, None, None] ) -> Generator[GraphEngineEvent, None, None]: for event in generator: - if isinstance(event, NodeRunStreamChunkEvent): + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ event.route_node_state.node_id @@ -99,6 +105,9 @@ class AnswerStreamProcessor: def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: node_ids = [] for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + node_ids.append(edge.target_node_id) node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) return node_ids @@ -107,6 +116,9 @@ class AnswerStreamProcessor: """ remove target node ids until merge """ + if node_id not in self.rest_node_ids: + return + self.rest_node_ids.remove(node_id) for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id in reachable_node_ids: diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 60bb97c0b1..b89fe7f3a6 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -4,6 +4,7 @@ from collections.abc import Generator from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, + NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) @@ -26,7 +27,12 @@ class EndStreamProcessor: generator: Generator[GraphEngineEvent, None, None] ) -> Generator[GraphEngineEvent, None, None]: for event in generator: - if isinstance(event, NodeRunStreamChunkEvent): + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ event.route_node_state.node_id @@ -87,6 +93,9 @@ class EndStreamProcessor: def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: node_ids = [] for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + node_ids.append(edge.target_node_id) node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) return node_ids @@ -95,6 +104,9 @@ class EndStreamProcessor: """ remove target node ids until merge """ + if node_id not in self.rest_node_ids: + return + self.rest_node_ids.remove(node_id) for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id in reachable_node_ids: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d077729307..3dce712408 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,15 +5,13 @@ from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.event import BaseGraphEvent, GraphRunFailedEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.run_condition import RunCondition -from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.utils.condition.entities import Condition -from core.workflow.workflow_entry import WorkflowRunFailedError from models.workflow import WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) @@ -74,7 +72,7 @@ class IterationNode(BaseNode): Condition( variable_selector=[self.node_id, "index"], comparison_operator="<", - value=len(iterator_list_value) + value=str(len(iterator_list_value)) ) ] ) @@ -93,6 +91,7 @@ class IterationNode(BaseNode): ) # init graph engine + from core.workflow.graph_engine.graph_engine import GraphEngine graph_engine = GraphEngine( tenant_id=self.tenant_id, app_id=self.app_id, @@ -114,10 +113,11 @@ class IterationNode(BaseNode): rst = graph_engine.run() outputs: list[Any] = [] for event in rst: - yield event if isinstance(event, NodeRunSucceededEvent): + yield event + # handle iteration run result - if event.node_id in iteration_leaf_node_ids: + if event.route_node_state.node_id in iteration_leaf_node_ids: # append to iteration output variable list outputs.append(variable_pool.get_any(self.node_data.output_selector)) @@ -132,13 +132,23 @@ class IterationNode(BaseNode): next_index ) - variable_pool.add( - [self.node_id, 'item'], - iterator_list_value[next_index] - ) - elif isinstance(event, GraphRunFailedEvent): + if next_index < len(iterator_list_value): + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[next_index] + ) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): # iteration run failed - raise WorkflowRunFailedError(event.reason) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.reason, + ) + ) + break + else: + yield event yield RunCompletedEvent( run_result=NodeRunResult( @@ -148,14 +158,6 @@ class IterationNode(BaseNode): } ) ) - except WorkflowRunFailedError as e: - # iteration run failed - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) - ) except Exception as e: # iteration run failed logger.exception("Iteration run failed") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 2354f7e678..61b6f6bbee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -23,7 +23,7 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @patch('extensions.ext_database.db.session.remove') @patch('extensions.ext_database.db.session.close') -def test_run_parallel_in_workflow(mock_close, mock_remove, mocker): +def test_run_parallel_in_workflow(mock_close, mock_remove): graph_config = { "edges": [ { diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py b/api/tests/unit_tests/core/workflow/nodes/iteration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py new file mode 100644 index 0000000000..0d4dc8b10c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -0,0 +1,209 @@ +import time +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunResult, SystemVariable, UserFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + + +def test_run(): + graph_config = { + "edges": [{ + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }], + "nodes": [{ + "data": { + "title": "Start", + "type": "start", + "variables": [] + }, + "id": "start" + }, { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer" + }, + "id": "answer-2" + }, { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{ + "value_selector": ["sys", "query"], + "variable": "arg1" + }] + }, + "id": "tt", + }, { + "data": { + "answer": "{{#iteration-1.output#}}88888", + "title": "answer 3", + "type": "answer" + }, + "id": "answer-3", + }, { + "data": { + "conditions": [{ + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"] + }], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else" + }, + "id": "if-else", + }, { + "data": { + "answer": "no hi", + "iteration_id": "iteration-1", + "title": "answer 4", + "type": "answer" + }, + "id": "answer-4", + }, { + "data": { + "instruction": "test1", + "model": { + "completion_params": { + "temperature": 0.7 + }, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai" + }, + "parameters": [{ + "description": "test", + "name": "list_output", + "required": False, + "type": "array[string]" + }], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor" + }, + "id": "pe", + }] + } + + graph = Graph.init( + graph_config=graph_config + ) + + init_params = GraphInitParams( + tenant_id='1', + app_id='1', + workflow_type=WorkflowType.CHAT, + workflow_id='1', + graph_config=graph_config, + user_id='1', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0 + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'dify', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: '1' + }, user_inputs={}, environment_variables=[]) + pool.add(['pe', 'list_output'], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + start_at=time.perf_counter() + ), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'iterator_selector': 'dify' + }, + outputs={ + 'output': 'dify 123' + } + ) + + # print("") + + with patch.object(TemplateTransformNode, '_run', new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 15