mirror of https://github.com/langgenius/dify.git
save
This commit is contained in:
parent
aaa98c76d5
commit
1d8ecac093
|
|
@ -1,8 +1,9 @@
|
|||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
|
@ -20,7 +21,7 @@ class RuntimeNode(BaseModel):
|
|||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
id: str
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random id for current runtime node"""
|
||||
|
||||
graph_node: GraphNode
|
||||
|
|
@ -97,4 +98,4 @@ class WorkflowRuntimeState(BaseModel):
|
|||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
runtime_graph: RuntimeGraph
|
||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,9 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
|
|
@ -35,6 +37,7 @@ from extensions.ext_database import db
|
|||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
|
|
@ -80,17 +83,17 @@ class WorkflowEngineManager:
|
|||
:param variable_pool: variable pool
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_dict = workflow.graph_dict
|
||||
if not graph_dict:
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
if 'nodes' not in graph_dict or 'edges' not in graph_dict:
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_dict.get('nodes'), list):
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_dict.get('edges'), list):
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init variable pool
|
||||
|
|
@ -106,15 +109,19 @@ class WorkflowEngineManager:
|
|||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
|
||||
# init workflow run state
|
||||
workflow_run_state = WorkflowRunState(
|
||||
workflow=workflow,
|
||||
start_at=time.perf_counter(),
|
||||
variable_pool=variable_pool,
|
||||
# init workflow runtime state
|
||||
workflow_runtime_state = WorkflowRuntimeState(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
variable_pool=variable_pool,
|
||||
invoke_from=invoke_from,
|
||||
workflow_call_depth=call_depth
|
||||
graph=graph_config,
|
||||
call_depth=call_depth,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
|
|
@ -124,8 +131,8 @@ class WorkflowEngineManager:
|
|||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph=graph_dict,
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph_config=graph_config,
|
||||
workflow_runtime_state=workflow_runtime_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
|
@ -134,21 +141,26 @@ class WorkflowEngineManager:
|
|||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _run_workflow(self, graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
def _run_workflow(self, graph_config: dict,
|
||||
workflow_runtime_state: WorkflowRuntimeState,
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
start_node: Optional[str] = None,
|
||||
end_node: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run workflow
|
||||
:param graph: workflow graph
|
||||
:param workflow_run_state: workflow run state
|
||||
:param graph_config: workflow graph config
|
||||
:param workflow_runtime_state: workflow runtime state
|
||||
:param callbacks: workflow callbacks
|
||||
:param start_node: force specific start node (gte)
|
||||
:param end_node: force specific end node (le)
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# init graph
|
||||
graph = Graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
predecessor_node: Optional[BaseNode] = None
|
||||
current_iteration_node: Optional[BaseIterationNode] = None
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
|
|
@ -159,7 +171,7 @@ class WorkflowEngineManager:
|
|||
# get next nodes
|
||||
next_nodes = self._get_next_overall_nodes(
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph=graph,
|
||||
graph=graph_config,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks,
|
||||
node_start_at=start_node,
|
||||
|
|
|
|||
Loading…
Reference in New Issue