This commit is contained in:
takatost 2024-06-27 05:30:38 +08:00
parent aaa98c76d5
commit 1d8ecac093
2 changed files with 34 additions and 21 deletions

View File

@ -1,8 +1,9 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
@ -20,7 +21,7 @@ class RuntimeNode(BaseModel):
FAILED = "failed"
PAUSED = "paused"
id: str
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
@ -97,4 +98,4 @@ class WorkflowRuntimeState(BaseModel):
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -13,7 +13,9 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
@ -35,6 +37,7 @@ from extensions.ext_database import db
from models.workflow import (
Workflow,
WorkflowNodeExecutionStatus,
WorkflowType,
)
node_classes = {
@ -80,17 +83,17 @@ class WorkflowEngineManager:
:param variable_pool: variable pool
"""
# fetch workflow graph
graph_dict = workflow.graph_dict
if not graph_dict:
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
if 'nodes' not in graph_dict or 'edges' not in graph_dict:
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_dict.get('nodes'), list):
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_dict.get('edges'), list):
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init variable pool
@ -106,15 +109,19 @@ class WorkflowEngineManager:
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow run state
workflow_run_state = WorkflowRunState(
workflow=workflow,
start_at=time.perf_counter(),
variable_pool=variable_pool,
# init workflow runtime state
workflow_runtime_state = WorkflowRuntimeState(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
user_id=user_id,
user_from=user_from,
variable_pool=variable_pool,
invoke_from=invoke_from,
workflow_call_depth=call_depth
graph=graph_config,
call_depth=call_depth,
start_at=time.perf_counter()
)
# init workflow run
@ -124,8 +131,8 @@ class WorkflowEngineManager:
# run workflow
self._run_workflow(
graph=graph_dict,
workflow_run_state=workflow_run_state,
graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
)
@ -134,21 +141,26 @@ class WorkflowEngineManager:
callbacks=callbacks
)
def _run_workflow(self, graph: dict,
workflow_run_state: WorkflowRunState,
def _run_workflow(self, graph_config: dict,
workflow_runtime_state: WorkflowRuntimeState,
callbacks: list[BaseWorkflowCallback],
start_node: Optional[str] = None,
end_node: Optional[str] = None) -> None:
"""
Run workflow
:param graph: workflow graph
:param workflow_run_state: workflow run state
:param graph_config: workflow graph config
:param workflow_runtime_state: workflow runtime state
:param callbacks: workflow callbacks
:param start_node: force specific start node (gte)
:param end_node: force specific end node (le)
:return:
"""
try:
# init graph
graph = Graph(
graph_config=graph_config
)
predecessor_node: Optional[BaseNode] = None
current_iteration_node: Optional[BaseIterationNode] = None
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
@ -159,7 +171,7 @@ class WorkflowEngineManager:
# get next nodes
next_nodes = self._get_next_overall_nodes(
workflow_run_state=workflow_run_state,
graph=graph,
graph=graph_config,
predecessor_node=predecessor_node,
callbacks=callbacks,
node_start_at=start_node,