From 1d8ecac093f167b6105e2880983b5db551252857 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 27 Jun 2024 05:30:38 +0800 Subject: [PATCH] save --- .../workflow_runtime_state_entities.py | 7 +-- api/core/workflow/workflow_engine_manager.py | 48 ++++++++++++------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/api/core/workflow/entities/workflow_runtime_state_entities.py b/api/core/workflow/entities/workflow_runtime_state_entities.py index 39141c52e7..cee39a1a99 100644 --- a/api/core/workflow/entities/workflow_runtime_state_entities.py +++ b/api/core/workflow/entities/workflow_runtime_state_entities.py @@ -1,8 +1,9 @@ +import uuid from datetime import datetime, timezone from enum import Enum from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult @@ -20,7 +21,7 @@ class RuntimeNode(BaseModel): FAILED = "failed" PAUSED = "paused" - id: str + id: str = Field(default_factory=lambda: str(uuid.uuid4())) """random id for current runtime node""" graph_node: GraphNode @@ -97,4 +98,4 @@ class WorkflowRuntimeState(BaseModel): total_tokens: int = 0 node_run_steps: int = 0 - runtime_graph: RuntimeGraph + runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 4007af85a1..544438a5bf 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -13,7 +13,9 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState +from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph import Graph from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode @@ -35,6 +37,7 @@ from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, + WorkflowType, ) node_classes = { @@ -80,17 +83,17 @@ class WorkflowEngineManager: :param variable_pool: variable pool """ # fetch workflow graph - graph_dict = workflow.graph_dict - if not graph_dict: + graph_config = workflow.graph_dict + if not graph_config: raise ValueError('workflow graph not found') - if 'nodes' not in graph_dict or 'edges' not in graph_dict: + if 'nodes' not in graph_config or 'edges' not in graph_config: raise ValueError('nodes or edges not found in workflow graph') - if not isinstance(graph_dict.get('nodes'), list): + if not isinstance(graph_config.get('nodes'), list): raise ValueError('nodes in workflow graph must be a list') - if not isinstance(graph_dict.get('edges'), list): + if not isinstance(graph_config.get('edges'), list): raise ValueError('edges in workflow graph must be a list') # init variable pool @@ -106,15 +109,19 @@ class WorkflowEngineManager: if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, + # init workflow runtime state + workflow_runtime_state = WorkflowRuntimeState( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), user_id=user_id, user_from=user_from, + variable_pool=variable_pool, invoke_from=invoke_from, - workflow_call_depth=call_depth + graph=graph_config, + call_depth=call_depth, + start_at=time.perf_counter() ) # init workflow run @@ -124,8 +131,8 @@ class WorkflowEngineManager: # run workflow self._run_workflow( - graph=graph_dict, - workflow_run_state=workflow_run_state, + graph_config=graph_config, + workflow_runtime_state=workflow_runtime_state, callbacks=callbacks, ) @@ -134,21 +141,26 @@ class WorkflowEngineManager: callbacks=callbacks ) - def _run_workflow(self, graph: dict, - workflow_run_state: WorkflowRunState, + def _run_workflow(self, graph_config: dict, + workflow_runtime_state: WorkflowRuntimeState, callbacks: list[BaseWorkflowCallback], start_node: Optional[str] = None, end_node: Optional[str] = None) -> None: """ Run workflow - :param graph: workflow graph - :param workflow_run_state: workflow run state + :param graph_config: workflow graph config + :param workflow_runtime_state: workflow runtime state :param callbacks: workflow callbacks :param start_node: force specific start node (gte) :param end_node: force specific end node (le) :return: """ try: + # init graph + graph = Graph( + graph_config=graph_config + ) + predecessor_node: Optional[BaseNode] = None current_iteration_node: Optional[BaseIterationNode] = None max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") @@ -159,7 +171,7 @@ class WorkflowEngineManager: # get next nodes next_nodes = self._get_next_overall_nodes( workflow_run_state=workflow_run_state, - graph=graph, + graph=graph_config, predecessor_node=predecessor_node, callbacks=callbacks, node_start_at=start_node,