mirror of https://github.com/langgenius/dify.git
feat(workflow): fix iteration single debug
This commit is contained in:
parent
6f6b32e1ee
commit
1da5862a96
@@ -7,7 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
@@ -15,7 +15,6 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
@@ -84,86 +83,84 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
node_id = self.application_generate_entity.single_iteration_run.node_id
user_inputs = self.application_generate_entity.single_iteration_run.inputs

generator = WorkflowEntry.single_step_run_iteration(
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=node_id,
user_id=self.application_generate_entity.user_id,
user_inputs=user_inputs,
callbacks=workflow_callbacks
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files

# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return

# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return

db.session.close()

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: self.conversation.id,
SystemVariable.USER_ID: user_id,
}

# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)

for event in generator:
# TODO
self._handle_event(workflow_entry, event)
return

inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files

# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return

# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return

db.session.close()

# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
session.commit()
# Convert database entities to variables
conversation_variables = [item.to_variable() for item in conversation_variables]

# Create a variable pool.
system_inputs = {
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: self.conversation.id,
SystemVariable.USER_ID: user_id,
}

# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
workflow=workflow,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=workflow.type,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
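Editor's note: read old against new, this hunk replaces the dedicated `WorkflowEntry.single_step_run_iteration(...)` call (and its private event loop) with a two-step flow: derive a trimmed graph and variable pool for the iteration node, then fall through to the shared `WorkflowEntry` run path. A minimal sketch of the new control flow, with names taken from the diff and elided details marked `...`:

```python
# Sketch (not verbatim): the single-iteration branch now converges with the
# normal run path instead of driving its own engine loop.
if self.application_generate_entity.single_iteration_run:
    single_iteration_run = self.application_generate_entity.single_iteration_run
    graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
        workflow=workflow,
        node_id=single_iteration_run.node_id,
        user_inputs=single_iteration_run.inputs,
    )
else:
    ...  # moderation, annotation reply, conversation variables (unchanged logic)
    variable_pool = VariablePool(
        system_variables=system_inputs,
        user_inputs=inputs,
        environment_variables=workflow.environment_variables,
        conversation_variables=conversation_variables,
    )
    graph = self._init_graph(graph_config=workflow.graph_dict)

# RUN WORKFLOW: both branches hand the prepared graph to WorkflowEntry.
workflow_entry = WorkflowEntry(workflow=workflow, graph=graph, ...)
```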
@@ -276,25 +276,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
event=event
)

yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)

yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
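The pipeline change is the same for node start, success, and failure: the stream-response helpers can now return None (see their Optional return types in the WorkflowCycleManage hunks below), so the `yield` is guarded. The resulting pattern, taken directly from the diff:

```python
# The helper may return None (e.g. for iteration/loop container nodes),
# so only a real response is streamed to the client.
response = self._workflow_node_start_to_stream_response(
    event=event,
    task_id=self._application_generate_entity.task_id,
    workflow_node_execution=workflow_node_execution,
)
if response:
    yield response
```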
@@ -71,41 +71,41 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):

# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
node_id = self.application_generate_entity.single_iteration_run.node_id
user_inputs = self.application_generate_entity.single_iteration_run.inputs

generator = WorkflowEntry.single_step_run_iteration(
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=node_id,
user_id=self.application_generate_entity.user_id,
user_inputs=user_inputs,
callbacks=workflow_callbacks
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:

inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files

# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}

variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)

for event in generator:
# TODO
self._handle_event(workflow_entry, event)
return

inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files

# Create a variable pool.
system_inputs = {
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id,
}

variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
workflow=workflow,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=workflow.type,
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
@@ -249,25 +249,34 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
event=event
)

yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)

yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)

if response:
yield response
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Mapping, Optional, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@@ -18,6 +18,8 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
@@ -36,6 +38,10 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@@ -46,6 +52,122 @@ class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager

def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')

if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')

if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init graph
graph = Graph.init(
graph_config=graph_config
)

if not graph:
raise ValueError('graph not found in workflow')

return graph

def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')

graph_config = cast(dict[str, Any], graph_config)

if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')

if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')

if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')

# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]

graph_config['nodes'] = node_configs

node_ids = [node.get('id') for node in node_configs]

# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]

graph_config['edges'] = edge_configs

# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)

if not graph:
raise ValueError('graph not found in workflow')

# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break

if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')

# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)

# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)

return graph, variable_pool

def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
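The core of the new `_get_graph_and_variable_pool_of_single_iteration` helper is the sub-graph extraction: keep only the iteration node itself plus nodes whose `data.iteration_id` points at it, then keep only edges whose endpoints survive the filter. A standalone sketch of that step, assuming `graph_config` is the usual `{'nodes': [...], 'edges': [...]}` mapping:

```python
# Keep the iteration node and every node nested inside it.
node_configs = [
    node for node in graph_config.get('nodes', [])
    if node.get('id') == node_id
    or node.get('data', {}).get('iteration_id', '') == node_id
]
node_ids = [node.get('id') for node in node_configs]

# Keep only edges whose source and target both survived the node filter.
edge_configs = [
    edge for edge in graph_config.get('edges', [])
    if (edge.get('source') is None or edge.get('source') in node_ids)
    and (edge.get('target') is None or edge.get('target') in node_ids)
]

graph_config['nodes'] = node_configs
graph_config['edges'] = edge_configs

# Root the resulting graph at the iteration node itself.
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
```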
@@ -231,7 +231,6 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)

workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp()
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
@@ -239,6 +238,7 @@ class WorkflowCycleManage:
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

db.session.commit()
db.session.refresh(workflow_node_execution)
@@ -259,11 +259,11 @@ class WorkflowCycleManage:

workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp()
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

db.session.commit()
db.session.refresh(workflow_node_execution)
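The elapsed-time hunks above fix a real bug: `time.perf_counter()` is a monotonic clock with an arbitrary reference point, so subtracting a wall-clock `datetime.timestamp()` from it never yielded a meaningful duration. The fix subtracts two datetimes on the same (naive UTC) clock:

```python
import time
from datetime import datetime, timezone

# Old (broken): perf_counter's epoch is arbitrary, so this delta is meaningless.
# elapsed_time = time.perf_counter() - event.start_at.timestamp()

# New: finished_at and start_at are both naive UTC datetimes,
# so their difference is a real duration in seconds.
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
```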
@@ -344,7 +344,7 @@ class WorkflowCycleManage:

def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> NodeStartStreamResponse:
) -> Optional[NodeStartStreamResponse]:
"""
Workflow node start to stream response.
:param event: queue node started event
@@ -352,6 +352,9 @@ class WorkflowCycleManage:
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None

response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -380,13 +383,16 @@ class WorkflowCycleManage:

def _workflow_node_finish_to_stream_response(
self, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> NodeFinishStreamResponse:
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None

return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -483,7 +489,7 @@ class WorkflowCycleManage:
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - event.start_at.timestamp(),
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
@@ -1,4 +1,3 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode

@@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@@ -400,81 +400,85 @@ class GraphEngine:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or 'Unknown error.',
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or 'Unknown error.',
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)

if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage

# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
)

yield NodeRunSucceededEvent(
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage

# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)

yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)

break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)

break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
except GenerateTaskStoppedException:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
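Structurally, this hunk teaches the node-run loop that `node_instance.run()` may now yield fully-formed `GraphEngineEvent`s (the iteration events) in addition to the node-local `Run*` events. A condensed sketch of the new dispatch, assuming (as in the source) that the `Run*` event classes are not `GraphEngineEvent` subclasses:

```python
for item in node_instance.run():
    if isinstance(item, GraphEngineEvent):
        if isinstance(item, BaseIterationEvent):
            # Stamp parallel context onto iteration events before forwarding.
            item.parallel_id = parallel_id
            item.parallel_start_node_id = parallel_start_node_id
        yield item  # engine-level events pass through unchanged
    elif isinstance(item, RunCompletedEvent):
        ...  # translate into NodeRunFailedEvent / NodeRunSucceededEvent
    elif isinstance(item, RunStreamChunkEvent):
        ...  # translate into NodeRunStreamChunkEvent
    elif isinstance(item, RunRetrieverResourceEvent):
        ...  # translate into NodeRunRetrieverResourceEvent
```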
@@ -1,8 +1,7 @@
from typing import Any, Mapping, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
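This and the following import hunks are one mechanical cleanup applied across the node modules: since Python 3.9 (PEP 585), the container ABCs in `typing` are deprecated aliases, so `Mapping` and `Sequence` move to `collections.abc` while `typing` keeps only what it alone provides:

```python
# Before
# from typing import Any, Mapping, Sequence, cast

# After
from collections.abc import Mapping, Sequence
from typing import Any, cast
```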
@@ -1,4 +1,5 @@
from typing import Any, Mapping, Optional, Sequence, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
@@ -1,6 +1,6 @@
from typing import Any, Mapping, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
@@ -1,12 +1,12 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any, Mapping, Sequence, cast
from typing import Any, cast

from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
@@ -1,6 +1,6 @@
from typing import Any, Mapping, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
@@ -5,7 +5,6 @@ from typing import Any, cast

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
@@ -1,5 +1,6 @@
import logging
from typing import Any, Mapping, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

from sqlalchemy import func

@@ -13,7 +14,6 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
@@ -1,7 +1,7 @@
import json
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import Any, Mapping, Optional, Sequence, cast
from typing import Any, Optional, cast

from pydantic import BaseModel

@@ -24,7 +24,6 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
@@ -1,6 +1,7 @@
import json
import uuid
from typing import Any, Mapping, Optional, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Mapping, Optional, Sequence, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -14,7 +15,6 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
@@ -1,6 +1,7 @@

from typing import Any, Mapping, Sequence
from core.workflow.entities.base_node_data_entities import BaseNodeData
from collections.abc import Mapping, Sequence
from typing import Any

from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
@@ -1,5 +1,6 @@
import os
from typing import Any, Mapping, Optional, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
@@ -1,6 +1,6 @@
from typing import Any, Mapping, Sequence, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, Type, cast
from typing import Any, Optional, cast

from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
@@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@@ -20,8 +20,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.event import RunEvent
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.node_mapping import node_classes
from models.workflow import (
@@ -35,7 +34,12 @@ logger = logging.getLogger(__name__)
class WorkflowEntry:
def __init__(
self,
workflow: Workflow,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_type: WorkflowType,
graph_config: Mapping[str, Any],
graph: Graph,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -43,46 +47,29 @@ class WorkflowEntry:
variable_pool: VariablePool
) -> None:
"""
:param workflow: Workflow instance
Init workflow entry
:param tenant_id: tenant id
:param app_id: app id
:param workflow_id: workflow id
:param workflow_type: workflow type
:param graph_config: workflow graph config
:param graph: workflow graph
:param user_id: user id
:param user_from: user from
:param invoke_from: invoke from service-api, web-app, debugger, explore
:param invoke_from: invoke from
:param call_depth: call depth
:param variable_pool: variable pool
:param single_step_run_iteration_id: single step run iteration id
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')

if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')

if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')

if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')

workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))

# init graph
graph = Graph.init(
graph_config=graph_config
)

if not graph:
raise ValueError('graph not found in workflow')

# init workflow run state
self.graph_engine = GraphEngine(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type),
workflow_id=workflow.id,
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
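With graph validation moved into the runners (`_init_graph` above), `WorkflowEntry.__init__` no longer builds the graph itself; callers construct it first and pass everything in explicitly. A hedged usage sketch based on the runner hunks earlier in this commit; the `user_from`, `invoke_from`, and `call_depth` values are assumptions that vary by call site:

```python
graph = self._init_graph(graph_config=workflow.graph_dict)

workflow_entry = WorkflowEntry(
    workflow=workflow,
    tenant_id=workflow.tenant_id,
    app_id=workflow.app_id,
    workflow_id=workflow.id,
    workflow_type=workflow.type,
    graph=graph,
    graph_config=workflow.graph_dict,
    user_id=self.application_generate_entity.user_id,
    user_from=UserFrom.ACCOUNT,    # assumption: account-initiated run
    invoke_from=self.application_generate_entity.invoke_from,  # assumption
    call_depth=self.application_generate_entity.call_depth,    # assumption
    variable_pool=variable_pool,
)
```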
@@ -133,152 +120,6 @@ class WorkflowEntry:
)
return

@classmethod
def single_step_run_iteration(
cls,
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
callbacks: Sequence[WorkflowCallback],
) -> Generator[GraphEngineEvent, None, None]:
"""
Single step run workflow node iteration
:param workflow: Workflow instance
:param node_id: node id
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')

graph_config = cast(dict[str, Any], graph_config)

if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')

if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')

if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')

# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]

graph_config['nodes'] = node_configs

node_ids = [node.get('id') for node in node_configs]

# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]

graph_config['edges'] = edge_configs

# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)

if not graph:
raise ValueError('graph not found in workflow')

# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break

if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')

# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)

# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)

try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}

cls._mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)

# init workflow run state
graph_engine = GraphEngine(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type),
workflow_id=workflow.id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=1,
graph=graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)

try:
# run workflow
generator = graph_engine.run()
for event in generator:
if callbacks:
for callback in callbacks:
callback.on_event(
graph=graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=event
)
yield event
except GenerateTaskStoppedException:
pass
except Exception as e:
logger.exception("Unknown Error when workflow entry running")
if callbacks:
for callback in callbacks:
callback.on_event(
graph=graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=GraphRunFailedEvent(
error=str(e)
)
)
return

@classmethod
def single_step_run(
cls,
@@ -366,7 +207,7 @@ class WorkflowEntry:
except NotImplementedError:
variable_mapping = {}

cls._mapping_user_inputs_to_variable_pool(
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
@@ -413,7 +254,7 @@ class WorkflowEntry:
return new_value

@classmethod
def _mapping_user_inputs_to_variable_pool(
def mapping_user_inputs_to_variable_pool(
cls,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
@@ -428,11 +269,11 @@ class WorkflowEntry:
if len(node_variable_list) < 1:
raise ValueError(f'Invalid node variable {node_variable}')

node_variable_key = node_variable_list[1:]
node_variable_key = '.'.join(node_variable_list[1:])

if (
node_variable_key not in user_inputs
or node_variable not in user_inputs
and node_variable not in user_inputs
) and not variable_pool.get(variable_selector):
raise ValueError(f'Variable key {node_variable} not found in user inputs.')
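This last hunk fixes two bugs in one condition: `node_variable_list[1:]` is a list and could never match a key in `user_inputs`, and the `or` made the guard fire whenever either key was merely absent. Joining the tail into a dotted key and switching to `and` makes the lookup behave as intended. A small illustration (the selector string is hypothetical):

```python
node_variable = 'start.args.input'        # hypothetical selector string
node_variable_list = node_variable.split('.')

# Old (broken): a list can never equal a dict key.
# node_variable_key = node_variable_list[1:]            # ['args', 'input']

# New: a dotted key that can actually appear in user_inputs.
node_variable_key = '.'.join(node_variable_list[1:])    # 'args.input'

# The value may be supplied under either the joined key or the full selector;
# only if BOTH are missing (and the variable pool has no value) is it an error.
missing = (
    node_variable_key not in user_inputs
    and node_variable not in user_inputs   # was `or`, which over-matched
)
```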