mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 09:36:40 +08:00
completed workflow engine main logic
This commit is contained in:
parent
dd50deaa43
commit
7d28fe8ea5
@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
# RUN WORKFLOW
|
# RUN WORKFLOW
|
||||||
workflow_engine_manager = WorkflowEngineManager()
|
workflow_engine_manager = WorkflowEngineManager()
|
||||||
workflow_engine_manager.run_workflow(
|
workflow_engine_manager.run_workflow(
|
||||||
app_model=app_record,
|
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
SystemVariable.FILES: files,
|
SystemVariable.FILES: files,
|
||||||
SystemVariable.CONVERSATION: conversation.id,
|
SystemVariable.CONVERSATION: conversation.id,
|
||||||
},
|
},
|
||||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||||
|
|||||||
@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
'error': workflow_run.error,
|
'error': workflow_run.error,
|
||||||
'elapsed_time': workflow_run.elapsed_time,
|
'elapsed_time': workflow_run.elapsed_time,
|
||||||
'total_tokens': workflow_run.total_tokens,
|
'total_tokens': workflow_run.total_tokens,
|
||||||
'total_price': workflow_run.total_price,
|
|
||||||
'currency': workflow_run.currency,
|
|
||||||
'total_steps': workflow_run.total_steps,
|
'total_steps': workflow_run.total_steps,
|
||||||
'created_at': int(workflow_run.created_at.timestamp()),
|
'created_at': int(workflow_run.created_at.timestamp()),
|
||||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||||
|
|
||||||
|
|
||||||
@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||||||
workflow_node_execution_id=workflow_node_execution.id,
|
workflow_node_execution_id=workflow_node_execution.id,
|
||||||
pub_from=PublishFrom.TASK_PIPELINE
|
pub_from=PublishFrom.TASK_PIPELINE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_text_chunk(self, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Publish text chunk
|
||||||
|
"""
|
||||||
|
self._queue_manager.publish_text_chunk(
|
||||||
|
text=text,
|
||||||
|
pub_from=PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
|
|||||||
@ -31,3 +31,11 @@ class BaseWorkflowCallback:
|
|||||||
Workflow node execute finished
|
Workflow node execute finished
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_text_chunk(self, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Publish text chunk
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -1,4 +1,9 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class NodeType(Enum):
|
class NodeType(Enum):
|
||||||
@ -39,3 +44,19 @@ class SystemVariable(Enum):
|
|||||||
QUERY = 'query'
|
QUERY = 'query'
|
||||||
FILES = 'files'
|
FILES = 'files'
|
||||||
CONVERSATION = 'conversation'
|
CONVERSATION = 'conversation'
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRunResult(BaseModel):
|
||||||
|
"""
|
||||||
|
Node Run Result.
|
||||||
|
"""
|
||||||
|
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||||
|
|
||||||
|
inputs: Optional[dict] = None # node inputs
|
||||||
|
process_data: Optional[dict] = None # process data
|
||||||
|
outputs: Optional[dict] = None # node outputs
|
||||||
|
metadata: Optional[dict] = None # node metadata
|
||||||
|
|
||||||
|
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||||
|
|
||||||
|
error: Optional[str] = None # error message if status is failed
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||||
|
|
||||||
@ -10,7 +8,10 @@ class WorkflowRunState:
|
|||||||
variable_pool: VariablePool
|
variable_pool: VariablePool
|
||||||
|
|
||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
total_price: Decimal = Decimal(0)
|
|
||||||
currency: str = "USD"
|
|
||||||
|
|
||||||
workflow_node_executions: list[WorkflowNodeExecution] = []
|
workflow_node_executions: list[WorkflowNodeExecution] = []
|
||||||
|
|
||||||
|
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
|
||||||
|
self.workflow_run = workflow_run
|
||||||
|
self.start_at = start_at
|
||||||
|
self.variable_pool = variable_pool
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class BaseNode:
|
class BaseNode:
|
||||||
@ -13,17 +14,23 @@ class BaseNode:
|
|||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
|
node_run_result: Optional[NodeRunResult] = None
|
||||||
|
|
||||||
def __init__(self, config: dict) -> None:
|
stream_output_supported: bool = False
|
||||||
|
callbacks: list[BaseWorkflowCallback]
|
||||||
|
|
||||||
|
def __init__(self, config: dict,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||||
self.node_id = config.get("id")
|
self.node_id = config.get("id")
|
||||||
if not self.node_id:
|
if not self.node_id:
|
||||||
raise ValueError("Node ID is required.")
|
raise ValueError("Node ID is required.")
|
||||||
|
|
||||||
self.node_data = self._node_data_cls(**config.get("data", {}))
|
self.node_data = self._node_data_cls(**config.get("data", {}))
|
||||||
|
self.callbacks = callbacks or []
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||||
run_args: Optional[dict] = None) -> dict:
|
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
@ -33,22 +40,41 @@ class BaseNode:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def run(self, variable_pool: Optional[VariablePool] = None,
|
def run(self, variable_pool: Optional[VariablePool] = None,
|
||||||
run_args: Optional[dict] = None,
|
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||||
callbacks: list[BaseWorkflowCallback] = None) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Run node entry
|
Run node entry
|
||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
:param run_args: run args
|
:param run_args: run args
|
||||||
:param callbacks: callbacks
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if variable_pool is None and run_args is None:
|
if variable_pool is None and run_args is None:
|
||||||
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
|
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
|
||||||
|
|
||||||
return self._run(
|
try:
|
||||||
variable_pool=variable_pool,
|
result = self._run(
|
||||||
run_args=run_args
|
variable_pool=variable_pool,
|
||||||
)
|
run_args=run_args
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# process unhandled exception
|
||||||
|
result = NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.node_run_result = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
def publish_text_chunk(self, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Publish text chunk
|
||||||
|
:param text: chunk text
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.stream_output_supported:
|
||||||
|
if self.callbacks:
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_text_chunk(text)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
from core.workflow.entities.workflow_entities import WorkflowRunState
|
from core.workflow.entities.workflow_entities import WorkflowRunState
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.code.code_node import CodeNode
|
from core.workflow.nodes.code.code_node import CodeNode
|
||||||
@ -31,6 +32,7 @@ from models.workflow import (
|
|||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
WorkflowRunTriggeredFrom,
|
WorkflowRunTriggeredFrom,
|
||||||
|
WorkflowType,
|
||||||
)
|
)
|
||||||
|
|
||||||
node_classes = {
|
node_classes = {
|
||||||
@ -120,8 +122,7 @@ class WorkflowEngineManager:
|
|||||||
|
|
||||||
return default_config
|
return default_config
|
||||||
|
|
||||||
def run_workflow(self, app_model: App,
|
def run_workflow(self, workflow: Workflow,
|
||||||
workflow: Workflow,
|
|
||||||
triggered_from: WorkflowRunTriggeredFrom,
|
triggered_from: WorkflowRunTriggeredFrom,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
@ -129,7 +130,6 @@ class WorkflowEngineManager:
|
|||||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Run workflow
|
Run workflow
|
||||||
:param app_model: App instance
|
|
||||||
:param workflow: Workflow instance
|
:param workflow: Workflow instance
|
||||||
:param triggered_from: triggered from
|
:param triggered_from: triggered from
|
||||||
:param user: account or end user
|
:param user: account or end user
|
||||||
@ -143,13 +143,23 @@ class WorkflowEngineManager:
|
|||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError('workflow graph not found')
|
raise ValueError('workflow graph not found')
|
||||||
|
|
||||||
|
if 'nodes' not in graph or 'edges' not in graph:
|
||||||
|
raise ValueError('nodes or edges not found in workflow graph')
|
||||||
|
|
||||||
|
if isinstance(graph.get('nodes'), list):
|
||||||
|
raise ValueError('nodes in workflow graph must be a list')
|
||||||
|
|
||||||
|
if isinstance(graph.get('edges'), list):
|
||||||
|
raise ValueError('edges in workflow graph must be a list')
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
workflow_run = self._init_workflow_run(
|
workflow_run = self._init_workflow_run(
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
triggered_from=triggered_from,
|
triggered_from=triggered_from,
|
||||||
user=user,
|
user=user,
|
||||||
user_inputs=user_inputs,
|
user_inputs=user_inputs,
|
||||||
system_inputs=system_inputs
|
system_inputs=system_inputs,
|
||||||
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
@ -161,44 +171,54 @@ class WorkflowEngineManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if callbacks:
|
# fetch predecessor node ids before end node (include: llm, direct answer)
|
||||||
for callback in callbacks:
|
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
|
||||||
callback.on_workflow_run_started(workflow_run)
|
|
||||||
|
|
||||||
# fetch start node
|
|
||||||
start_node = self._get_entry_node(graph)
|
|
||||||
if not start_node:
|
|
||||||
self._workflow_run_failed(
|
|
||||||
workflow_run_state=workflow_run_state,
|
|
||||||
error='Start node not found in workflow graph',
|
|
||||||
callbacks=callbacks
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
predecessor_node = None
|
predecessor_node = None
|
||||||
current_node = start_node
|
|
||||||
while True:
|
while True:
|
||||||
# run workflow
|
# get next node, multiple target nodes in the future
|
||||||
self._run_workflow_node(
|
next_node = self._get_next_node(
|
||||||
workflow_run_state=workflow_run_state,
|
graph=graph,
|
||||||
node=current_node,
|
|
||||||
predecessor_node=predecessor_node,
|
predecessor_node=predecessor_node,
|
||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_node.node_type == NodeType.END:
|
if not next_node:
|
||||||
break
|
break
|
||||||
|
|
||||||
# todo fetch next node until end node finished or no next node
|
# check if node is streamable
|
||||||
current_node = None
|
if next_node.node_id in streamable_node_ids:
|
||||||
|
next_node.stream_output_supported = True
|
||||||
|
|
||||||
if not current_node:
|
# max steps 30 reached
|
||||||
break
|
if len(workflow_run_state.workflow_node_executions) > 30:
|
||||||
|
raise ValueError('Max steps 30 reached.')
|
||||||
|
|
||||||
predecessor_node = current_node
|
|
||||||
# or max steps 30 reached
|
|
||||||
# or max execution time 10min reached
|
# or max execution time 10min reached
|
||||||
|
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
|
||||||
|
raise ValueError('Max execution time 10min reached.')
|
||||||
|
|
||||||
|
# run workflow, run multiple target nodes in the future
|
||||||
|
self._run_workflow_node(
|
||||||
|
workflow_run_state=workflow_run_state,
|
||||||
|
node=next_node,
|
||||||
|
predecessor_node=predecessor_node,
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
if next_node.node_type == NodeType.END:
|
||||||
|
break
|
||||||
|
|
||||||
|
predecessor_node = next_node
|
||||||
|
|
||||||
|
if not predecessor_node and not next_node:
|
||||||
|
self._workflow_run_failed(
|
||||||
|
workflow_run_state=workflow_run_state,
|
||||||
|
error='Start node not found in workflow graph.',
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._workflow_run_failed(
|
self._workflow_run_failed(
|
||||||
workflow_run_state=workflow_run_state,
|
workflow_run_state=workflow_run_state,
|
||||||
@ -213,11 +233,40 @@ class WorkflowEngineManager:
|
|||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
|
||||||
|
"""
|
||||||
|
Fetch streamable node ids
|
||||||
|
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||||
|
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||||
|
|
||||||
|
:param workflow: Workflow instance
|
||||||
|
:param graph: workflow graph
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
workflow_type = WorkflowType.value_of(workflow.type)
|
||||||
|
|
||||||
|
streamable_node_ids = []
|
||||||
|
end_node_ids = []
|
||||||
|
for node_config in graph.get('nodes'):
|
||||||
|
if node_config.get('type') == NodeType.END.value:
|
||||||
|
if workflow_type == WorkflowType.WORKFLOW:
|
||||||
|
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||||
|
end_node_ids.append(node_config.get('id'))
|
||||||
|
else:
|
||||||
|
end_node_ids.append(node_config.get('id'))
|
||||||
|
|
||||||
|
for edge_config in graph.get('edges'):
|
||||||
|
if edge_config.get('target') in end_node_ids:
|
||||||
|
streamable_node_ids.append(edge_config.get('source'))
|
||||||
|
|
||||||
|
return streamable_node_ids
|
||||||
|
|
||||||
def _init_workflow_run(self, workflow: Workflow,
|
def _init_workflow_run(self, workflow: Workflow,
|
||||||
triggered_from: WorkflowRunTriggeredFrom,
|
triggered_from: WorkflowRunTriggeredFrom,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
system_inputs: Optional[dict] = None,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
|
||||||
"""
|
"""
|
||||||
Init workflow run
|
Init workflow run
|
||||||
:param workflow: Workflow instance
|
:param workflow: Workflow instance
|
||||||
@ -225,6 +274,7 @@ class WorkflowEngineManager:
|
|||||||
:param user: account or end user
|
:param user: account or end user
|
||||||
:param user_inputs: user variables inputs
|
:param user_inputs: user variables inputs
|
||||||
:param system_inputs: system inputs, like: query, files
|
:param system_inputs: system inputs, like: query, files
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@ -260,6 +310,39 @@ class WorkflowEngineManager:
|
|||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if callbacks:
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_workflow_run_started(workflow_run)
|
||||||
|
|
||||||
|
return workflow_run
|
||||||
|
|
||||||
|
def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
|
||||||
|
"""
|
||||||
|
Workflow run success
|
||||||
|
:param workflow_run_state: workflow run state
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
workflow_run = workflow_run_state.workflow_run
|
||||||
|
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||||
|
|
||||||
|
# fetch last workflow_node_executions
|
||||||
|
last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1]
|
||||||
|
if last_workflow_node_execution:
|
||||||
|
workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs)
|
||||||
|
|
||||||
|
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
|
||||||
|
workflow_run.total_tokens = workflow_run_state.total_tokens
|
||||||
|
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
|
||||||
|
workflow_run.finished_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
if callbacks:
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_workflow_run_finished(workflow_run)
|
||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
|
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
|
||||||
@ -277,9 +360,8 @@ class WorkflowEngineManager:
|
|||||||
workflow_run.error = error
|
workflow_run.error = error
|
||||||
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
|
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
|
||||||
workflow_run.total_tokens = workflow_run_state.total_tokens
|
workflow_run.total_tokens = workflow_run_state.total_tokens
|
||||||
workflow_run.total_price = workflow_run_state.total_price
|
|
||||||
workflow_run.currency = workflow_run_state.currency
|
|
||||||
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
|
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
|
||||||
|
workflow_run.finished_at = datetime.utcnow()
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@ -289,21 +371,77 @@ class WorkflowEngineManager:
|
|||||||
|
|
||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
|
def _get_next_node(self, graph: dict,
|
||||||
|
predecessor_node: Optional[BaseNode] = None,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
|
||||||
"""
|
"""
|
||||||
Get entry node
|
Get next node
|
||||||
|
multiple target nodes in the future.
|
||||||
:param graph: workflow graph
|
:param graph: workflow graph
|
||||||
|
:param predecessor_node: predecessor node
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
nodes = graph.get('nodes')
|
nodes = graph.get('nodes')
|
||||||
if not nodes:
|
if not nodes:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for node_config in nodes.items():
|
if not predecessor_node:
|
||||||
if node_config.get('type') == NodeType.START.value:
|
for node_config in nodes:
|
||||||
return StartNode(config=node_config)
|
if node_config.get('type') == NodeType.START.value:
|
||||||
|
return StartNode(config=node_config)
|
||||||
|
else:
|
||||||
|
edges = graph.get('edges')
|
||||||
|
source_node_id = predecessor_node.node_id
|
||||||
|
|
||||||
return None
|
# fetch all outgoing edges from source node
|
||||||
|
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
|
||||||
|
if not outgoing_edges:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# fetch target node id from outgoing edges
|
||||||
|
outgoing_edge = None
|
||||||
|
source_handle = predecessor_node.node_run_result.edge_source_handle
|
||||||
|
if source_handle:
|
||||||
|
for edge in outgoing_edges:
|
||||||
|
if edge.get('source_handle') and edge.get('source_handle') == source_handle:
|
||||||
|
outgoing_edge = edge
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
outgoing_edge = outgoing_edges[0]
|
||||||
|
|
||||||
|
if not outgoing_edge:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_node_id = outgoing_edge.get('target')
|
||||||
|
|
||||||
|
# fetch target node from target node id
|
||||||
|
target_node_config = None
|
||||||
|
for node in nodes:
|
||||||
|
if node.get('id') == target_node_id:
|
||||||
|
target_node_config = node
|
||||||
|
break
|
||||||
|
|
||||||
|
if not target_node_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get next node
|
||||||
|
target_node = node_classes.get(NodeType.value_of(target_node_config.get('type')))
|
||||||
|
|
||||||
|
return target_node(
|
||||||
|
config=target_node_config,
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check timeout
|
||||||
|
:param start_at: start time
|
||||||
|
:param max_execution_time: max execution time
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# TODO check queue is stopped
|
||||||
|
return time.perf_counter() - start_at > max_execution_time
|
||||||
|
|
||||||
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
|
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
|
||||||
node: BaseNode,
|
node: BaseNode,
|
||||||
@ -320,28 +458,41 @@ class WorkflowEngineManager:
|
|||||||
# add to workflow node executions
|
# add to workflow node executions
|
||||||
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
|
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
|
||||||
|
|
||||||
try:
|
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
node_run_result = node.run(
|
||||||
node_run_result = node.run(
|
variable_pool=workflow_run_state.variable_pool
|
||||||
variable_pool=workflow_run_state.variable_pool,
|
)
|
||||||
callbacks=callbacks
|
|
||||||
)
|
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
except Exception as e:
|
|
||||||
# node run failed
|
# node run failed
|
||||||
self._workflow_node_execution_failed(
|
self._workflow_node_execution_failed(
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
error=str(e),
|
start_at=start_at,
|
||||||
|
error=node_run_result.error,
|
||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
raise
|
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
|
||||||
|
|
||||||
# node run success
|
# node run success
|
||||||
self._workflow_node_execution_success(
|
self._workflow_node_execution_success(
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
start_at=start_at,
|
||||||
result=node_run_result,
|
result=node_run_result,
|
||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for variable_key, variable_value in node_run_result.outputs.items():
|
||||||
|
# append variables to variable pool recursively
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=workflow_run_state.variable_pool,
|
||||||
|
node_id=node.node_id,
|
||||||
|
variable_key_list=[variable_key],
|
||||||
|
variable_value=variable_value
|
||||||
|
)
|
||||||
|
|
||||||
|
if node_run_result.metadata.get('total_tokens'):
|
||||||
|
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens'))
|
||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
|
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
|
||||||
@ -384,3 +535,86 @@ class WorkflowEngineManager:
|
|||||||
callback.on_workflow_node_execute_started(workflow_node_execution)
|
callback.on_workflow_node_execute_started(workflow_node_execution)
|
||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||||
|
start_at: float,
|
||||||
|
result: NodeRunResult,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
|
||||||
|
"""
|
||||||
|
Workflow node execution success
|
||||||
|
:param workflow_node_execution: workflow node execution
|
||||||
|
:param start_at: start time
|
||||||
|
:param result: node run result
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
|
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||||
|
workflow_node_execution.inputs = json.dumps(result.inputs)
|
||||||
|
workflow_node_execution.process_data = json.dumps(result.process_data)
|
||||||
|
workflow_node_execution.outputs = json.dumps(result.outputs)
|
||||||
|
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
|
||||||
|
workflow_node_execution.finished_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
if callbacks:
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_workflow_node_execute_finished(workflow_node_execution)
|
||||||
|
|
||||||
|
return workflow_node_execution
|
||||||
|
|
||||||
|
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||||
|
start_at: float,
|
||||||
|
error: str,
|
||||||
|
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
|
||||||
|
"""
|
||||||
|
Workflow node execution failed
|
||||||
|
:param workflow_node_execution: workflow node execution
|
||||||
|
:param start_at: start time
|
||||||
|
:param error: error message
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
|
workflow_node_execution.error = error
|
||||||
|
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||||
|
workflow_node_execution.finished_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
if callbacks:
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_workflow_node_execute_finished(workflow_node_execution)
|
||||||
|
|
||||||
|
return workflow_node_execution
|
||||||
|
|
||||||
|
def _append_variables_recursively(self, variable_pool: VariablePool,
|
||||||
|
node_id: str,
|
||||||
|
variable_key_list: list[str],
|
||||||
|
variable_value: VariableValue):
|
||||||
|
"""
|
||||||
|
Append variables recursively
|
||||||
|
:param variable_pool: variable pool
|
||||||
|
:param node_id: node id
|
||||||
|
:param variable_key_list: variable key list
|
||||||
|
:param variable_value: variable value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
variable_pool.append_variable(
|
||||||
|
node_id=node_id,
|
||||||
|
variable_key_list=variable_key_list,
|
||||||
|
value=variable_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
if isinstance(variable_value, dict):
|
||||||
|
for key, value in variable_value.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = variable_key_list + [key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_id=node_id,
|
||||||
|
variable_key_list=new_key_list,
|
||||||
|
variable_value=value
|
||||||
|
)
|
||||||
|
|||||||
@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
|
|||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"elapsed_time": fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
"total_tokens": fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
"total_price": fields.Float,
|
|
||||||
"currency": fields.String,
|
|
||||||
"total_steps": fields.Integer,
|
"total_steps": fields.Integer,
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"finished_at": TimestampField
|
"finished_at": TimestampField
|
||||||
@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
|
|||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"elapsed_time": fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
"total_tokens": fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
"total_price": fields.Float,
|
|
||||||
"currency": fields.String,
|
|
||||||
"total_steps": fields.Integer,
|
"total_steps": fields.Integer,
|
||||||
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
@ -56,8 +52,6 @@ workflow_run_detail_fields = {
|
|||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"elapsed_time": fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
"total_tokens": fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
"total_price": fields.Float,
|
|
||||||
"currency": fields.String,
|
|
||||||
"total_steps": fields.Integer,
|
"total_steps": fields.Integer,
|
||||||
"created_by_role": fields.String,
|
"created_by_role": fields.String,
|
||||||
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
||||||
|
|||||||
@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
|
|||||||
- error (string) `optional` Error reason
|
- error (string) `optional` Error reason
|
||||||
- elapsed_time (float) `optional` Time consumption (s)
|
- elapsed_time (float) `optional` Time consumption (s)
|
||||||
- total_tokens (int) `optional` Total tokens used
|
- total_tokens (int) `optional` Total tokens used
|
||||||
- total_price (decimal) `optional` Total cost
|
|
||||||
- currency (string) `optional` Currency, such as USD / RMB
|
|
||||||
- total_steps (int) Total steps (redundant), default 0
|
- total_steps (int) Total steps (redundant), default 0
|
||||||
- created_by_role (string) Creator role
|
- created_by_role (string) Creator role
|
||||||
|
|
||||||
@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
|
|||||||
error = db.Column(db.Text)
|
error = db.Column(db.Text)
|
||||||
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
|
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
|
||||||
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||||
total_price = db.Column(db.Numeric(10, 7))
|
|
||||||
currency = db.Column(db.String(255))
|
|
||||||
total_steps = db.Column(db.Integer, server_default=db.text('0'))
|
total_steps = db.Column(db.Integer, server_default=db.text('0'))
|
||||||
created_by_role = db.Column(db.String(255), nullable=False)
|
created_by_role = db.Column(db.String(255), nullable=False)
|
||||||
created_by = db.Column(UUID, nullable=False)
|
created_by = db.Column(UUID, nullable=False)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user