mirror of https://github.com/langgenius/dify.git
completed workflow engine main logic
This commit is contained in:
parent
dd50deaa43
commit
7d28fe8ea5
|
|
@ -83,7 +83,6 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
app_model=app_record,
|
||||
workflow=workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
|
|
@ -94,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
|
|
|
|||
|
|
@ -253,8 +253,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
'error': workflow_run.error,
|
||||
'elapsed_time': workflow_run.elapsed_time,
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_price': workflow_run.total_price,
|
||||
'currency': workflow_run.currency,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
|
|
@ -43,3 +43,12 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish_text_chunk(
|
||||
text=text,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,3 +31,11 @@ class BaseWorkflowCallback:
|
|||
Workflow node execute finished
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -1,4 +1,9 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
|
|
@ -39,3 +44,19 @@ class SystemVariable(Enum):
|
|||
QUERY = 'query'
|
||||
FILES = 'files'
|
||||
CONVERSATION = 'conversation'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[dict] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[dict] = None # node outputs
|
||||
metadata: Optional[dict] = None # node metadata
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
|
@ -10,7 +8,10 @@ class WorkflowRunState:
|
|||
variable_pool: VariablePool
|
||||
|
||||
total_tokens: int = 0
|
||||
total_price: Decimal = Decimal(0)
|
||||
currency: str = "USD"
|
||||
|
||||
workflow_node_executions: list[WorkflowNodeExecution] = []
|
||||
|
||||
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
|
||||
self.workflow_run = workflow_run
|
||||
self.start_at = start_at
|
||||
self.variable_pool = variable_pool
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class BaseNode:
|
||||
|
|
@ -13,17 +14,23 @@ class BaseNode:
|
|||
|
||||
node_id: str
|
||||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
def __init__(self, config: dict) -> None:
|
||||
stream_output_supported: bool = False
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
|
||||
def __init__(self, config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
self.node_id = config.get("id")
|
||||
if not self.node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_data = self._node_data_cls(**config.get("data", {}))
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> dict:
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
|
|
@ -33,22 +40,41 @@ class BaseNode:
|
|||
raise NotImplementedError
|
||||
|
||||
def run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> dict:
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node entry
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:param callbacks: callbacks
|
||||
:return:
|
||||
"""
|
||||
if variable_pool is None and run_args is None:
|
||||
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
|
||||
|
||||
return self._run(
|
||||
variable_pool=variable_pool,
|
||||
run_args=run_args
|
||||
)
|
||||
try:
|
||||
result = self._run(
|
||||
variable_pool=variable_pool,
|
||||
run_args=run_args
|
||||
)
|
||||
except Exception as e:
|
||||
# process unhandled exception
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
||||
def publish_text_chunk(self, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
:param text: chunk text
|
||||
:return:
|
||||
"""
|
||||
if self.stream_output_supported:
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_text_chunk(text)
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.workflow.callbacks.base_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowRunState
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
|
|
@ -31,6 +32,7 @@ from models.workflow import (
|
|||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
|
|
@ -120,8 +122,7 @@ class WorkflowEngineManager:
|
|||
|
||||
return default_config
|
||||
|
||||
def run_workflow(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
def run_workflow(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
|
|
@ -129,7 +130,6 @@ class WorkflowEngineManager:
|
|||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Run workflow
|
||||
:param app_model: App instance
|
||||
:param workflow: Workflow instance
|
||||
:param triggered_from: triggered from
|
||||
:param user: account or end user
|
||||
|
|
@ -143,13 +143,23 @@ class WorkflowEngineManager:
|
|||
if not graph:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
if 'nodes' not in graph or 'edges' not in graph:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if isinstance(graph.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if isinstance(graph.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init workflow run
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=workflow,
|
||||
triggered_from=triggered_from,
|
||||
user=user,
|
||||
user_inputs=user_inputs,
|
||||
system_inputs=system_inputs
|
||||
system_inputs=system_inputs,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
|
|
@ -161,44 +171,54 @@ class WorkflowEngineManager:
|
|||
)
|
||||
)
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started(workflow_run)
|
||||
|
||||
# fetch start node
|
||||
start_node = self._get_entry_node(graph)
|
||||
if not start_node:
|
||||
self._workflow_run_failed(
|
||||
workflow_run_state=workflow_run_state,
|
||||
error='Start node not found in workflow graph',
|
||||
callbacks=callbacks
|
||||
)
|
||||
return
|
||||
# fetch predecessor node ids before end node (include: llm, direct answer)
|
||||
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
|
||||
|
||||
try:
|
||||
predecessor_node = None
|
||||
current_node = start_node
|
||||
while True:
|
||||
# run workflow
|
||||
self._run_workflow_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
node=current_node,
|
||||
# get next node, multiple target nodes in the future
|
||||
next_node = self._get_next_node(
|
||||
graph=graph,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if current_node.node_type == NodeType.END:
|
||||
if not next_node:
|
||||
break
|
||||
|
||||
# todo fetch next node until end node finished or no next node
|
||||
current_node = None
|
||||
# check if node is streamable
|
||||
if next_node.node_id in streamable_node_ids:
|
||||
next_node.stream_output_supported = True
|
||||
|
||||
if not current_node:
|
||||
break
|
||||
# max steps 30 reached
|
||||
if len(workflow_run_state.workflow_node_executions) > 30:
|
||||
raise ValueError('Max steps 30 reached.')
|
||||
|
||||
predecessor_node = current_node
|
||||
# or max steps 30 reached
|
||||
# or max execution time 10min reached
|
||||
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=600):
|
||||
raise ValueError('Max execution time 10min reached.')
|
||||
|
||||
# run workflow, run multiple target nodes in the future
|
||||
self._run_workflow_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
node=next_node,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if next_node.node_type == NodeType.END:
|
||||
break
|
||||
|
||||
predecessor_node = next_node
|
||||
|
||||
if not predecessor_node and not next_node:
|
||||
self._workflow_run_failed(
|
||||
workflow_run_state=workflow_run_state,
|
||||
error='Start node not found in workflow graph.',
|
||||
callbacks=callbacks
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
self._workflow_run_failed(
|
||||
workflow_run_state=workflow_run_state,
|
||||
|
|
@ -213,11 +233,40 @@ class WorkflowEngineManager:
|
|||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param workflow: Workflow instance
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
workflow_type = WorkflowType.value_of(workflow.type)
|
||||
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('type') == NodeType.END.value:
|
||||
if workflow_type == WorkflowType.WORKFLOW:
|
||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
else:
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
||||
system_inputs: Optional[dict] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
|
||||
"""
|
||||
Init workflow run
|
||||
:param workflow: Workflow instance
|
||||
|
|
@ -225,6 +274,7 @@ class WorkflowEngineManager:
|
|||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
|
|
@ -260,6 +310,39 @@ class WorkflowEngineManager:
|
|||
db.session.rollback()
|
||||
raise
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started(workflow_run)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run_state: workflow run state
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
workflow_run = workflow_run_state.workflow_run
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
|
||||
# fetch last workflow_node_executions
|
||||
last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1]
|
||||
if last_workflow_node_execution:
|
||||
workflow_run.outputs = json.dumps(last_workflow_node_execution.node_run_result.outputs)
|
||||
|
||||
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
|
||||
workflow_run.total_tokens = workflow_run_state.total_tokens
|
||||
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_finished(workflow_run)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
|
||||
|
|
@ -277,9 +360,8 @@ class WorkflowEngineManager:
|
|||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
|
||||
workflow_run.total_tokens = workflow_run_state.total_tokens
|
||||
workflow_run.total_price = workflow_run_state.total_price
|
||||
workflow_run.currency = workflow_run_state.currency
|
||||
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -289,21 +371,77 @@ class WorkflowEngineManager:
|
|||
|
||||
return workflow_run
|
||||
|
||||
def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
|
||||
def _get_next_node(self, graph: dict,
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
|
||||
"""
|
||||
Get entry node
|
||||
Get next node
|
||||
multiple target nodes in the future.
|
||||
:param graph: workflow graph
|
||||
:param predecessor_node: predecessor node
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
for node_config in nodes.items():
|
||||
if node_config.get('type') == NodeType.START.value:
|
||||
return StartNode(config=node_config)
|
||||
if not predecessor_node:
|
||||
for node_config in nodes:
|
||||
if node_config.get('type') == NodeType.START.value:
|
||||
return StartNode(config=node_config)
|
||||
else:
|
||||
edges = graph.get('edges')
|
||||
source_node_id = predecessor_node.node_id
|
||||
|
||||
return None
|
||||
# fetch all outgoing edges from source node
|
||||
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
|
||||
if not outgoing_edges:
|
||||
return None
|
||||
|
||||
# fetch target node id from outgoing edges
|
||||
outgoing_edge = None
|
||||
source_handle = predecessor_node.node_run_result.edge_source_handle
|
||||
if source_handle:
|
||||
for edge in outgoing_edges:
|
||||
if edge.get('source_handle') and edge.get('source_handle') == source_handle:
|
||||
outgoing_edge = edge
|
||||
break
|
||||
else:
|
||||
outgoing_edge = outgoing_edges[0]
|
||||
|
||||
if not outgoing_edge:
|
||||
return None
|
||||
|
||||
target_node_id = outgoing_edge.get('target')
|
||||
|
||||
# fetch target node from target node id
|
||||
target_node_config = None
|
||||
for node in nodes:
|
||||
if node.get('id') == target_node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
return None
|
||||
|
||||
# get next node
|
||||
target_node = node_classes.get(NodeType.value_of(target_node_config.get('type')))
|
||||
|
||||
return target_node(
|
||||
config=target_node_config,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
:param start_at: start time
|
||||
:param max_execution_time: max execution time
|
||||
:return:
|
||||
"""
|
||||
# TODO check queue is stopped
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
|
||||
node: BaseNode,
|
||||
|
|
@ -320,28 +458,41 @@ class WorkflowEngineManager:
|
|||
# add to workflow node executions
|
||||
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
|
||||
|
||||
try:
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
# node run failed
|
||||
self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
error=str(e),
|
||||
start_at=start_at,
|
||||
error=node_run_result.error,
|
||||
callbacks=callbacks
|
||||
)
|
||||
raise
|
||||
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
|
||||
|
||||
# node run success
|
||||
self._workflow_node_execution_success(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=start_at,
|
||||
result=node_run_result,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
for variable_key, variable_value in node_run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
node_id=node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value
|
||||
)
|
||||
|
||||
if node_run_result.metadata.get('total_tokens'):
|
||||
workflow_run_state.total_tokens += int(node_run_result.metadata.get('total_tokens'))
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
|
||||
|
|
@ -384,3 +535,86 @@ class WorkflowEngineManager:
|
|||
callback.on_workflow_node_execute_started(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
result: NodeRunResult,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution success
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param result: node run result
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.inputs = json.dumps(result.inputs)
|
||||
workflow_node_execution.process_data = json.dumps(result.process_data)
|
||||
workflow_node_execution.outputs = json.dumps(result.outputs)
|
||||
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_node_execute_finished(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
error: str,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param error: error message
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_node_execute_finished(workflow_node_execution)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _append_variables_recursively(self, variable_pool: VariablePool,
|
||||
node_id: str,
|
||||
variable_key_list: list[str],
|
||||
variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param variable_pool: variable pool
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_pool.append_variable(
|
||||
node_id=node_id,
|
||||
variable_key_list=variable_key_list,
|
||||
value=variable_value
|
||||
)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool,
|
||||
node_id=node_id,
|
||||
variable_key_list=new_key_list,
|
||||
variable_value=value
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ workflow_run_for_log_fields = {
|
|||
"error": fields.String,
|
||||
"elapsed_time": fields.Float,
|
||||
"total_tokens": fields.Integer,
|
||||
"total_price": fields.Float,
|
||||
"currency": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": TimestampField
|
||||
|
|
@ -29,8 +27,6 @@ workflow_run_for_list_fields = {
|
|||
"error": fields.String,
|
||||
"elapsed_time": fields.Float,
|
||||
"total_tokens": fields.Integer,
|
||||
"total_price": fields.Float,
|
||||
"currency": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
|
|
@ -56,8 +52,6 @@ workflow_run_detail_fields = {
|
|||
"error": fields.String,
|
||||
"elapsed_time": fields.Float,
|
||||
"total_tokens": fields.Integer,
|
||||
"total_price": fields.Float,
|
||||
"currency": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"created_by_role": fields.String,
|
||||
"created_by_account": fields.Nested(simple_account_fields, attribute='created_by_account', allow_null=True),
|
||||
|
|
|
|||
|
|
@ -216,8 +216,6 @@ class WorkflowRun(db.Model):
|
|||
- error (string) `optional` Error reason
|
||||
- elapsed_time (float) `optional` Time consumption (s)
|
||||
- total_tokens (int) `optional` Total tokens used
|
||||
- total_price (decimal) `optional` Total cost
|
||||
- currency (string) `optional` Currency, such as USD / RMB
|
||||
- total_steps (int) Total steps (redundant), default 0
|
||||
- created_by_role (string) Creator role
|
||||
|
||||
|
|
@ -251,8 +249,6 @@ class WorkflowRun(db.Model):
|
|||
error = db.Column(db.Text)
|
||||
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0'))
|
||||
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
|
||||
total_price = db.Column(db.Numeric(10, 7))
|
||||
currency = db.Column(db.String(255))
|
||||
total_steps = db.Column(db.Integer, server_default=db.text('0'))
|
||||
created_by_role = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(UUID, nullable=False)
|
||||
|
|
|
|||
Loading…
Reference in New Issue