diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 08795c934b..957e744aac 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -141,3 +141,15 @@ class VariablePool(BaseModel): return hash_key = hash(tuple(selector[1:])) self.variable_dictionary[selector[0]].pop(hash_key, None) + + def remove_node(self, node_id: str, /): + """ + Remove all variables associated with a given node id. + + Args: + node_id (str): The node id to remove. + + Returns: + None + """ + self.variable_dictionary.pop(node_id, None) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 85a1799719..c1e058065f 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,5 +1,6 @@ import uuid -from typing import Optional, cast +from collections.abc import Mapping +from typing import Any, Optional, cast from pydantic import BaseModel, Field @@ -61,7 +62,7 @@ class Graph(BaseModel): @classmethod def init(cls, - graph_config: dict, + graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": """ Init graph diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py index d32d3eb4f3..1a403f3e49 100644 --- a/api/core/workflow/graph_engine/entities/graph_init_params.py +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom @@ -11,6 +14,7 @@ class GraphInitParams(BaseModel): app_id: str = Field(..., description="app id") workflow_type: WorkflowType = Field(..., description="workflow type") workflow_id: str = Field(..., description="workflow id") + graph_config: Mapping[str, Any] = Field(..., description="graph config") user_id: str = Field(..., description="user id") user_from: UserFrom = Field(..., description="user from, account or end-user") invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0a007bad96..7d3d389cfc 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -2,8 +2,8 @@ import logging import queue import threading import time -from collections.abc import Generator -from typing import Optional +from collections.abc import Generator, Mapping +from typing import Any, Optional from flask import Flask, current_app from uritemplate.variable import VariableValue @@ -41,24 +41,29 @@ logger = logging.getLogger(__name__) class GraphEngine: - def __init__(self, tenant_id: str, - app_id: str, - workflow_type: WorkflowType, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - call_depth: int, - graph: Graph, - variable_pool: VariablePool, - max_execution_steps: int, - max_execution_time: int) -> None: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int + ) -> None: self.graph = graph self.init_params = GraphInitParams( tenant_id=tenant_id, app_id=app_id, workflow_type=workflow_type, workflow_id=workflow_id, + graph_config=graph_config, user_id=user_id, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 150d417c21..c6fa532c91 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -2,14 +2,12 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Optional -from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent, RunEvent -from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin class BaseNode(ABC): @@ -26,6 +24,7 @@ class BaseNode(ABC): self.app_id = graph_init_params.app_id self.workflow_type = graph_init_params.workflow_type self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config self.user_id = graph_init_params.user_id self.user_from = graph_init_params.user_from self.invoke_from = graph_init_params.invoke_from @@ -100,37 +99,3 @@ class BaseNode(ABC): :return: """ return self._node_type - - -class BaseIterationNode(BaseNode, IterableNodeMixin): - @abstractmethod - def _run(self) -> BaseIterationState: - """ - Run node - :return: - """ - raise NotImplementedError - - def run(self) -> BaseIterationState: - """ - Run node entry - :return: - """ - return self._run(variable_pool=self.graph_runtime_state.variable_pool) - - def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - return self._get_next_iteration(variable_pool, state) - - @abstractmethod - def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - raise NotImplementedError diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e6ca0335dd..d077729307 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,16 +1,25 @@ +import logging from typing import Any, cast +from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode -from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.utils.condition.entities import Condition +from core.workflow.workflow_entry import WorkflowRunFailedError from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) -class IterationNode(BaseIterationNode): + +class IterationNode(BaseNode): """ Iteration Node. """ @@ -22,92 +31,144 @@ class IterationNode(BaseIterationNode): Run the node. """ self.node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(self.node_data.iterator_selector) + iterator_list_value = self.graph_runtime_state.variable_pool.get_any(self.node_data.iterator_selector) - if not isinstance(iterator, list): - raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") + if not isinstance(iterator_list_value, list): + raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ - 'iterator_selector': iterator - }, outputs=[], metadata=IterationState.MetaData( - iterator_length=len(iterator) if iterator is not None else 0 - )) - - self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state) - return state + root_node_id = self.node_data.start_node_id + graph_config = self.graph_config - def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - # resolve current output - self._resolve_current_output(variable_pool, state) - # move to next iteration - self._next_iteration(variable_pool, state) + # init graph + iteration_graph = Graph.init( + graph_config=graph_config, + root_node_id=root_node_id + ) - node_data = cast(IterationNodeData, self.node_data) - if self._reached_iteration_limit(variable_pool, state): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(state.outputs) - } + if not iteration_graph: + raise ValueError('iteration graph not found') + + leaf_node_ids = iteration_graph.get_leaf_node_ids() + iteration_leaf_node_ids = [] + for leaf_node_id in leaf_node_ids: + node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id) + if not node_config: + continue + + leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id") + if not leaf_node_iteration_id: + continue + + if leaf_node_iteration_id != self.node_id: + continue + + iteration_leaf_node_ids.append(leaf_node_id) + + # add condition of end nodes to root node + iteration_graph.add_extra_edge( + source_node_id=leaf_node_id, + target_node_id=root_node_id, + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=[self.node_id, "index"], + comparison_operator="<", + value=len(iterator_list_value) + ) + ] + ) ) - - return node_data.start_node_id - - def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): - """ - Set current iteration variable. - :variable_pool: variable pool - """ - node_data = cast(IterationNodeData, self.node_data) - variable_pool.add((self.node_id, 'index'), state.index) - # get the iterator value - iterator = variable_pool.get_any(node_data.iterator_selector) + variable_pool = self.graph_runtime_state.variable_pool - if iterator is None or not isinstance(iterator, list): - return - - if state.index < len(iterator): - variable_pool.add((self.node_id, 'item'), iterator[state.index]) + # append iteration variable (item, index) to variable pool + variable_pool.add( + [self.node_id, 'index'], + 0 + ) + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[0] + ) - def _next_iteration(self, variable_pool: VariablePool, state: IterationState): - """ - Move to next iteration. - :param variable_pool: variable pool - """ - state.index += 1 - self._set_current_iteration_variable(variable_pool, state) + # init graph engine + graph_engine = GraphEngine( + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_type=self.workflow_type, + workflow_id=self.workflow_id, + user_id=self.user_id, + user_from=self.user_from, + invoke_from=self.invoke_from, + call_depth=self.workflow_call_depth, + graph=iteration_graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + ) - def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): - """ - Check if iteration limit is reached. - :return: True if iteration limit is reached, False otherwise - """ - node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(node_data.iterator_selector) + try: + # run workflow + rst = graph_engine.run() + outputs: list[Any] = [] + for event in rst: + yield event + if isinstance(event, NodeRunSucceededEvent): + # handle iteration run result + if event.node_id in iteration_leaf_node_ids: + # append to iteration output variable list + outputs.append(variable_pool.get_any(self.node_data.output_selector)) - if iterator is None or not isinstance(iterator, list): - return True + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove_node(node_id) - return state.index >= len(iterator) - - def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): - """ - Resolve current output. - :param variable_pool: variable pool - """ - output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_any(output_selector) - # clear the output for this iteration - variable_pool.remove([self.node_id] + output_selector[1:]) - state.current_output = output - if output is not None: - state.outputs.append(output) + # move to next iteration + next_index = variable_pool.get_any([self.node_id, 'index']) + 1 + variable_pool.add( + [self.node_id, 'index'], + next_index + ) + + variable_pool.add( + [self.node_id, 'item'], + iterator_list_value[next_index] + ) + elif isinstance(event, GraphRunFailedEvent): + # iteration run failed + raise WorkflowRunFailedError(event.reason) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'output': jsonable_encoder(outputs) + } + ) + ) + except WorkflowRunFailedError as e: + # iteration run failed + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + except Exception as e: + # iteration run failed + logger.exception("Iteration run failed") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + finally: + # remove iteration variable (item, index) from variable pool after iteration run completed + variable_pool.remove([self.node_id, 'index']) + variable_pool.remove([self.node_id, 'item']) @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: @@ -119,19 +180,3 @@ class IterationNode(BaseIterationNode): return { 'input_selector': node_data.iterator_selector, } - - @classmethod - def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: - """ - Get conditions. - """ - node_id = node_config.get('id') - if not node_id: - return [] - - return [Condition( - variable_selector=[node_id, 'index'], - comparison_operator="≤", - value_type="value_selector", - value_selector=node_config.get('data', {}).get('iterator_selector') - )] diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 0c20312d84..526404e30d 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,13 +1,12 @@ from typing import Any -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.loop.entities import LoopNodeData, LoopState from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseIterationNode): +class LoopNode(BaseNode): """ Loop Node. """ @@ -17,12 +16,6 @@ class LoopNode(BaseIterationNode): def _run(self) -> LoopState: return super()._run() - def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - """ - pass - @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4fe2165777..1a788cd428 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -98,9 +98,10 @@ class WorkflowEntry: invoke_from=invoke_from, call_depth=call_depth, graph=graph, + graph_config=graph_config, variable_pool=variable_pool, - max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"), - max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) # init workflow run @@ -155,7 +156,6 @@ class WorkflowEntry: ) predecessor_node: BaseNode | None = None - current_iteration_node: BaseIterationNode | None = None has_entry_node = False max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME @@ -610,7 +610,7 @@ class WorkflowEntry: for callback in callbacks: callback.on_workflow_run_started() - def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: + def _workflow_run_success(self, callbacks: Sequence[BaseWorkflowCallback]) -> None: """ Workflow run success :param callbacks: workflow callbacks diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index d69dcc1c28..2354f7e678 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -211,6 +211,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove, mocker): app_id="222", workflow_type=WorkflowType.WORKFLOW, workflow_id="333", + graph_config=graph_config, user_id="444", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -372,6 +373,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): app_id="222", workflow_type=WorkflowType.CHAT, workflow_id="333", + graph_config=graph_config, user_id="444", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -531,6 +533,7 @@ def test_run_branch(mock_close, mock_remove): app_id="222", workflow_type=WorkflowType.CHAT, workflow_id="333", + graph_config=graph_config, user_id="444", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index f9e0989868..f76f9db312 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -46,6 +46,7 @@ def test_execute_answer(): app_id='1', workflow_type=WorkflowType.WORKFLOW, workflow_id='1', + graph_config=graph_config, user_id='1', user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 9f1e5c4517..d8111f7d73 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -34,6 +34,7 @@ def test_execute_if_else_result_true(): app_id='1', workflow_type=WorkflowType.WORKFLOW, workflow_id='1', + graph_config=graph_config, user_id='1', user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -237,6 +238,7 @@ def test_execute_if_else_result_false(): app_id='1', workflow_type=WorkflowType.WORKFLOW, workflow_id='1', + graph_config=graph_config, user_id='1', user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER,