diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py deleted file mode 100644 index 8d43155a08..0000000000 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager._publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager._publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self._queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py deleted file mode 100644 index 4472a7e9b5..0000000000 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager.publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager.publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - pass diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index 2e6431d6d0..dc9583c057 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -5,6 +5,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, \ + GraphRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunFailedEvent, NodeRunStreamChunkEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -20,87 +25,140 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: self.current_node_id = None - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self.print_text("\n[on_workflow_run_started]", color='pink') + def on_event( + self, + graph: Graph, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + event: GraphEngineEvent + ) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[on_workflow_run_started]", color='pink') + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[on_workflow_run_succeeded]", color='green') + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[on_workflow_run_failed] reason: {event.reason}", color='red') + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started( + graph=graph, + event=event + ) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded( + graph=graph, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + event=event + ) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed( + graph=graph, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + event=event + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk( + graph=graph, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + event=event + ) - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self.print_text("\n[on_workflow_run_succeeded]", color='green') - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self.print_text("\n[on_workflow_run_failed]", color='red') - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: + def on_workflow_node_execute_started( + self, + graph: Graph, + event: NodeRunStartedEvent + ) -> None: """ Workflow node execute started """ - self.print_text("\n[on_workflow_node_execute_started]", color='yellow') - self.print_text(f"Node ID: {node_id}", color='yellow') - self.print_text(f"Type: {node_type.value}", color='yellow') - self.print_text(f"Index: {node_run_index}", color='yellow') - if predecessor_node_id: - self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') + route_node_state = event.route_node_state + node_config = graph.node_id_config_mapping.get(route_node_state.node_id) + node_type = None + if node_config: + node_type = node_config.get("data", {}).get("type") - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: + self.print_text("\n[on_workflow_node_execute_started]", color='yellow') + self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow') + self.print_text(f"Type: {node_type}", color='yellow') + + def on_workflow_node_execute_succeeded( + self, + graph: Graph, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + event: NodeRunSucceededEvent + ) -> None: """ Workflow node execute succeeded """ - self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') - self.print_text(f"Node ID: {node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') - self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", - color='green') + route_node_state = event.route_node_state + node_config = graph.node_id_config_mapping.get(route_node_state.node_id) + node_type = None + if node_config: + node_type = node_config.get("data", {}).get("type") - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: + self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') + self.print_text(f"Node ID: {route_node_state.node_id}", color='green') + self.print_text(f"Type: {node_type.value}", color='green') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color='green') + self.print_text(f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", color='green') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color='green') + self.print_text(f"Metadata: {jsonable_encoder(node_run_result.execution_metadata) if node_run_result.execution_metadata else ''}", + color='green') + + def on_workflow_node_execute_failed( + self, + graph: Graph, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + event: NodeRunFailedEvent + ) -> None: """ Workflow node execute failed """ - self.print_text("\n[on_workflow_node_execute_failed]", color='red') - self.print_text(f"Node ID: {node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') - self.print_text(f"Error: {error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') + route_node_state = event.route_node_state + node_config = graph.node_id_config_mapping.get(route_node_state.node_id) + node_type = None + if node_config: + node_type = node_config.get("data", {}).get("type") - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: + self.print_text("\n[on_workflow_node_execute_failed]", color='red') + self.print_text(f"Node ID: {route_node_state.node_id}", color='red') + self.print_text(f"Type: {node_type.value}", color='red') + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color='red') + self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color='red') + self.print_text(f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", color='red') + self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color='red') + + def on_node_text_chunk( + self, + graph: Graph, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + event: NodeRunStreamChunkEvent + ) -> None: """ Publish text chunk """ - if not self.current_node_id or self.current_node_id != node_id: - self.current_node_id = node_id + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id self.print_text('\n[on_node_text_chunk]') - self.print_text(f"Node ID: {node_id}") - self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") + self.print_text(f"Node ID: {route_node_state.node_id}") - self.print_text(text, color="pink", end="") + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text(f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}") + + self.print_text(event.chunk_content, color="pink", end="") def on_workflow_iteration_started(self, node_id: str, @@ -135,13 +193,6 @@ class WorkflowLoggingCallback(WorkflowCallback): """ self.print_text("\n[on_workflow_iteration_completed]", color='blue') - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self.print_text("\n[on_workflow_event]", color='blue') - self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') - def print_text( self, text: str, color: Optional[str] = None, end: str = "\n" ) -> None: diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 6db8adf4c2..71f8804a2d 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,116 +1,21 @@ from abc import ABC, abstractmethod -from typing import Any, Optional -from core.app.entities.queue_entities import AppQueueEvent -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class WorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self) -> None: + def on_event( + self, + graph: Graph, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + event: GraphEngineEvent + ) -> None: """ - Workflow run started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - raise NotImplementedError - - @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any], - ) -> None: - """ - Publish iteration next - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - raise NotImplementedError - - @abstractmethod - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event + Published event """ raise NotImplementedError diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 3e669ea49c..1a50900a4f 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Any from pydantic import BaseModel, Field @@ -21,10 +21,6 @@ class GraphRunStartedEvent(BaseGraphEvent): pass -class GraphRunBackToRootEvent(BaseGraphEvent): - pass - - class GraphRunSucceededEvent(BaseGraphEvent): pass @@ -104,6 +100,11 @@ class IterationRunStartedEvent(BaseIterationEvent): pass +class IterationRunNextEvent(BaseIterationEvent): + index: int = Field(..., description="index") + pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + + class IterationRunSucceededEvent(BaseIterationEvent): pass diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b9e43e02a1..73723faa2c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -30,8 +30,6 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor - -# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_classes from extensions.ext_database import db @@ -78,10 +76,6 @@ class GraphEngine: self.max_execution_steps = max_execution_steps self.max_execution_time = max_execution_time - def run_in_block_mode(self): - # TODO convert generator to result - pass - def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 3dce712408..be56dae500 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,7 +5,8 @@ from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.event import BaseGraphEvent, GraphRunFailedEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.event import BaseGraphEvent, GraphRunFailedEvent, NodeRunSucceededEvent, \ + IterationRunStartedEvent, IterationRunSucceededEvent, IterationRunFailedEvent, IterationRunNextEvent from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.nodes.base_node import BaseNode @@ -108,6 +109,16 @@ class IterationNode(BaseNode): max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) + yield IterationRunStartedEvent( + iteration_id=self.node_id, + ) + + yield IterationRunNextEvent( + iteration_id=self.node_id, + index=0, + output=None + ) + try: # run workflow rst = graph_engine.run() @@ -119,7 +130,8 @@ class IterationNode(BaseNode): # handle iteration run result if event.route_node_state.node_id in iteration_leaf_node_ids: # append to iteration output variable list - outputs.append(variable_pool.get_any(self.node_data.output_selector)) + current_iteration_output = variable_pool.get_any(self.node_data.output_selector) + outputs.append(current_iteration_output) # remove all nodes outputs from variable pool for node_id in iteration_graph.node_ids: @@ -137,9 +149,20 @@ class IterationNode(BaseNode): [self.node_id, 'item'], iterator_list_value[next_index] ) + + yield IterationRunNextEvent( + iteration_id=self.node_id, + index=next_index, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None + ) elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed + yield IterationRunFailedEvent( + iteration_id=self.node_id, + reason=event.reason, + ) + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -150,6 +173,10 @@ class IterationNode(BaseNode): else: yield event + yield IterationRunSucceededEvent( + iteration_id=self.node_id, + ) + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -161,6 +188,11 @@ class IterationNode(BaseNode): except Exception as e: # iteration run failed logger.exception("Iteration run failed") + yield IterationRunFailedEvent( + iteration_id=self.node_id, + reason=str(e), + ) + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 1a788cd428..58a785db3e 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,5 +1,4 @@ import logging -import time from collections.abc import Mapping, Sequence from typing import Any, Optional, cast @@ -8,22 +7,18 @@ from core.app.app_config.entities import FileExtraConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable -from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState +from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom +from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import GraphRunFailedEvent from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom -from core.workflow.nodes.iteration.entities import IterationState +from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.node_mapping import node_classes -from core.workflow.nodes.start.start_node import StartNode -from extensions.ext_database import db from models.workflow import ( Workflow, - WorkflowNodeExecutionStatus, WorkflowType, ) @@ -52,7 +47,6 @@ class WorkflowEntry: :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files :param call_depth: call depth - :param variable_pool: variable pool """ # fetch workflow graph graph_config = workflow.graph_dict @@ -68,13 +62,6 @@ class WorkflowEntry: if not isinstance(graph_config.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - ) - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) @@ -87,6 +74,13 @@ class WorkflowEntry: if not graph: raise ValueError('graph not found in workflow') + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=user_inputs, + environment_variables=workflow.environment_variables, + ) + # init workflow run state graph_engine = GraphEngine( tenant_id=workflow.tenant_id, @@ -104,277 +98,32 @@ class WorkflowEntry: max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) - # init workflow run - self._workflow_run_started( - callbacks=callbacks - ) - try: # run workflow - rst = graph_engine.run() - except WorkflowRunFailedError as e: - self._workflow_run_failed( - error=e.error, - callbacks=callbacks - ) - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - return rst - - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[BaseWorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: - """ - Run workflow - :param graph_config: workflow graph config - :param workflow_runtime_state: workflow runtime state - :param callbacks: workflow callbacks - :param start_node: force specific start node (gte) - :param end_node: force specific end node (le) - :return: - """ - try: - # init graph - graph = self._init_graph( - graph_config=graph_config - ) - - if not graph: - raise WorkflowRunFailedError( - error='Start node not found in workflow graph.' - ) - - predecessor_node: BaseNode | None = None - has_entry_node = False - max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS - max_execution_time = dify_config.WORKFLOW_MAX_EXECUTION_TIME - while True: - # get next nodes - next_nodes = self._get_next_overall_nodes( - workflow_run_state=workflow_run_state, - graph=graph_config, - predecessor_node=predecessor_node, - callbacks=callbacks, - node_start_at=start_node, - node_end_at=end_node - ) - - if not next_nodes: - # reached loop/iteration end or overall end - if current_iteration_node and workflow_run_state.current_iteration_state: - # reached loop/iteration end - # get next iteration - next_iteration = current_iteration_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event( graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks + graph_init_params=graph_engine.init_params, + graph_runtime_state=graph_engine.graph_runtime_state, + event=event ) - if isinstance(next_iteration, NodeRunResult): - if next_iteration.outputs: - for variable_key, variable_value in next_iteration.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=current_iteration_node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - # iteration has ended - next_nodes = self._get_next_overall_nodes( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=current_iteration_node, - callbacks=callbacks, - node_start_at=start_node, - node_end_at=end_node - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - # continue overall process - elif isinstance(next_iteration, str): - # move to next iteration - next_node_id = next_iteration - # get next id - next_nodes = [self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks)] - - if not next_nodes: - break - - # max steps reached - if workflow_run_state.workflow_node_steps > max_execution_steps: - raise WorkflowRunFailedError('Max steps {} reached.'.format(max_execution_steps)) - - # or max execution time reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise WorkflowRunFailedError('Max execution time {}s reached.'.format(max_execution_time)) - - if len(next_nodes) == 1: - next_node = next_nodes[0] - - # run node - is_continue = self._run_node( - graph=graph, - workflow_run_state=workflow_run_state, - predecessor_node=predecessor_node, - current_node=next_node, - callbacks=callbacks - ) - - if not is_continue: - break - - predecessor_node = next_node - else: - result_dict = {} - - # # new thread - # worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={ - # 'flask_app': current_app._get_current_object(), - # 'graph': graph, - # 'workflow_run_state': workflow_run_state, - # 'predecessor_node': predecessor_node, - # 'next_nodes': next_nodes, - # 'callbacks': callbacks, - # 'result': result_dict - # }) - # - # worker_thread.start() - # worker_thread.join() - - if not workflow_run_state.workflow_node_runs: - raise WorkflowRunFailedError( - error='Start node not found in workflow graph.' - ) - except GenerateTaskStoppedException as e: - return + except GenerateTaskStoppedException: + pass except Exception as e: - raise WorkflowRunFailedError( - error=str(e) - ) - - # def _async_run_nodes(self, flask_app: Flask, - # graph: dict, - # workflow_run_state: WorkflowRunState, - # predecessor_node: Optional[BaseNode], - # next_nodes: list[BaseNode], - # callbacks: list[BaseWorkflowCallback], - # result: dict): - # with flask_app.app_context(): - # try: - # for next_node in next_nodes: - # # TODO run sub workflows - # # run node - # is_continue = self._run_node( - # graph=graph, - # workflow_run_state=workflow_run_state, - # predecessor_node=predecessor_node, - # current_node=next_node, - # callbacks=callbacks - # ) - # - # if not is_continue: - # break - # - # predecessor_node = next_node - # except Exception as e: - # logger.exception("Unknown Error when generating") - # finally: - # db.session.remove() - - def _run_node(self, graph: dict, - workflow_run_state: WorkflowRunState, - predecessor_node: Optional[BaseNode], - current_node: BaseNode, - callbacks: list[BaseWorkflowCallback]) -> bool: - """ - Run node - :param graph: workflow graph - :param workflow_run_state: current workflow run state - :param predecessor_node: predecessor node - :param current_node: current node for run - :param callbacks: workflow callbacks - :return: continue? - """ - # check is already ran - if self._check_node_has_ran(workflow_run_state, current_node.node_id): - return True - - # handle iteration nodes - if isinstance(current_node, BaseIterationNode): - current_iteration_node = current_node - workflow_run_state.current_iteration_state = current_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - callbacks=callbacks - ) - predecessor_node = current_node - # move to start node of iteration - current_node_id = current_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(current_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - return True - else: - # fetch next node in iteration - current_node = self._get_node(workflow_run_state, graph, current_node_id, callbacks) - - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=current_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) - - if current_node.node_type in [NodeType.END]: - return False - - return True + if callbacks: + for callback in callbacks: + callback.on_event( + graph=graph, + graph_init_params=graph_engine.init_params, + graph_runtime_state=graph_engine.graph_runtime_state, + event=GraphRunFailedEvent( + reason=str(e) + ) + ) + return def single_step_run_workflow_node(self, workflow: Workflow, node_id: str, @@ -462,537 +211,6 @@ class WorkflowEntry: return node_instance, node_run_result - def single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[BaseWorkflowCallback], - ) -> None: - """ - Single iteration run workflow node - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - for node in nodes: - if node.get('id') == node_id: - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]: - node_config = node - else: - raise ValueError('node id is not an iteration node') - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) - - # variable selector to variable mapping - iteration_nested_nodes = [ - node for node in nodes - if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id - ] - iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - for node_config in iteration_nested_nodes: - # mapping user inputs to variable pool - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - except NotImplementedError: - variable_mapping = {} - - # remove iteration variables - variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() - if value[0] != node_id - } - - # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_nested_node_ids - } - - # append variables to variable pool - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - callbacks=callbacks, - workflow_call_depth=0 - ) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # fetch end node of iteration - end_node_id = None - for edge in graph.get('edges'): - if edge.get('source') == node_id: - end_node_id = edge.get('target') - break - - if not end_node_id: - raise ValueError('end node of iteration not found') - - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - workflow_call_depth=0 - ) - - try: - # run workflow - self._run_workflow( - graph_config=workflow.graph, - workflow_runtime_state=workflow_runtime_state, - callbacks=callbacks, - start_node=node_id, - end_node=end_node_id - ) - except WorkflowRunFailedError as e: - self._workflow_run_failed( - error=e.error, - callbacks=callbacks - ) - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - def _workflow_run_started(self, callbacks: list[BaseWorkflowCallback] = None) -> None: - """ - Workflow run started - :param callbacks: workflow callbacks - :return: - """ - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - def _workflow_run_success(self, callbacks: Sequence[BaseWorkflowCallback]) -> None: - """ - Workflow run success - :param callbacks: workflow callbacks - :return: - """ - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_succeeded() - - def _workflow_run_failed(self, error: str, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run failed - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - if callbacks: - for callback in callbacks: - callback.on_workflow_run_failed( - error=error - ) - - def _workflow_iteration_started(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - predecessor_node_id: Optional[str] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration started - :param current_iteration_node: current iteration node - :param workflow_run_state: workflow run state - :param callbacks: workflow callbacks - :return: - """ - # get nested nodes - iteration_nested_nodes = [ - node for node in graph.get('nodes') - if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id - ] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_started( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - node_data=current_iteration_node.node_data, - inputs=workflow_run_state.current_iteration_state.inputs, - predecessor_node_id=predecessor_node_id, - metadata=workflow_run_state.current_iteration_state.metadata.model_dump() - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - def _workflow_iteration_next(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[BaseWorkflowCallback]) -> None: - """ - Workflow iteration next - :param workflow_run_state: workflow run state - :return: - """ - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_next( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - index=workflow_run_state.current_iteration_state.index, - node_run_index=workflow_run_state.workflow_node_steps, - output=workflow_run_state.current_iteration_state.get_current_output() - ) - # clear ran nodes - workflow_run_state.workflow_node_runs = [ - node_run for node_run in workflow_run_state.workflow_node_runs - if node_run.iteration_node_id != current_iteration_node.node_id - ] - - # clear variables in current iteration - nodes = graph.get('nodes') - nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] - - for node in nodes: - workflow_run_state.variable_pool.remove((node.get('id'),)) - - def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[BaseWorkflowCallback]) -> None: - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_completed( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - outputs={ - 'output': workflow_run_state.current_iteration_state.outputs - } - ) - - def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - callbacks: list[BaseWorkflowCallback], - predecessor_node: Optional[BaseNode] = None, - node_start_at: Optional[str] = None, - node_end_at: Optional[str] = None) -> Optional[BaseNode]: - """ - Get next nodes - multiple target nodes in the future. - :param graph: workflow graph - :param callbacks: workflow callbacks - :param predecessor_node: predecessor node - :param node_start_at: force specific start node - :param node_end_at: force specific end node - :return: target node list - """ - nodes = graph.get('nodes') - if not nodes: - return [] - - if not predecessor_node: - # fetch start node - for node_config in nodes: - node_cls = None - if node_start_at: - if node_config.get('id') == node_start_at: - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - else: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - node_cls = StartNode - - if node_cls: - return [node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - )] - - return [] - else: - edges = graph.get('edges') - edges = cast(list, edges) - source_node_id = predecessor_node.node_id - - # fetch all outgoing edges from source node - outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] - if not outgoing_edges: - return [] - - # fetch target node ids from outgoing edges - target_edges = [] - source_handle = predecessor_node.node_run_result.edge_source_handle \ - if predecessor_node.node_run_result else None - if source_handle: - for edge in outgoing_edges: - if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - target_edges.append(edge) - else: - target_edges = outgoing_edges - - if not target_edges: - return [] - - target_nodes = [] - for target_edge in target_edges: - target_node_id = target_edge.get('target') - - if node_end_at and target_node_id == node_end_at: - continue - - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break - - if not target_node_config: - continue - - # get next node - target_node_cls = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) - if not target_node_cls: - continue - - target_node = target_node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - target_nodes.append(target_node) - - return target_nodes - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - node_id: str, - callbacks: Sequence[WorkflowCallback]): - """ - Get node from graph by node id - """ - nodes = graph.get('nodes') - if not nodes: - return None - - for node_config in nodes: - if node_config.get('id') == node_id: - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes[node_type] - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: - """ - Check timeout - :param start_at: start time - :param max_execution_time: max execution time - :return: - """ - return time.perf_counter() - start_at > max_execution_time - - def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: - """ - Check node has ran - """ - return bool([ - node_and_result for node_and_result in workflow_run_state.workflow_node_runs - if node_and_result.node_id == node_id - ]) - - def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - node_run_index=workflow_run_state.workflow_node_steps, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None - ) - - db.session.close() - - workflow_nodes_and_result = WorkflowNodeAndResult( - node=node, - result=None - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - # mark node as running - if workflow_run_state.current_iteration_state: - workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( - node_id=node.node_id, - iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id - )) - - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) - except GenerateTaskStoppedException as e: - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error='Workflow stopped.' - ) - except Exception as e: - logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - # node run failed - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_failed( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - error=node_run_result.error, - inputs=node_run_result.inputs, - outputs=node_run_result.outputs, - process_data=node_run_result.process_data, - ) - - raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - - workflow_nodes_and_result.result = node_run_result - - # node run success - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_succeeded( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - inputs=node_run_result.inputs, - process_data=node_run_result.process_data, - outputs=node_run_result.outputs, - execution_metadata=node_run_result.metadata - ) - - if node_run_result.outputs: - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - - if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - db.session.close() - - def _append_variables_recursively(self, variable_pool: VariablePool, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): - """ - Append variables recursively - :param variable_pool: variable pool - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add( - [node_id] + variable_key_list, variable_value - ) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value - ) - @classmethod def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: """ @@ -1064,9 +282,4 @@ class WorkflowEntry: value = new_value # append variable and value to variable pool - variable_pool.add([variable_node_id]+variable_key_list, value) - - -class WorkflowRunFailedError(Exception): - def __init__(self, error: str): - self.error = error + variable_pool.add([variable_node_id] + variable_key_list, value) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 0d4dc8b10c..ba29431bf2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -206,4 +206,4 @@ def test_run(): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} - assert count == 15 + assert count == 20