From 1da5862a96fb53a036c848d7e2f413efa01223d0 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 15 Aug 2024 03:12:49 +0800 Subject: [PATCH] feat(workflow): fix iteration single debug --- api/core/app/apps/advanced_chat/app_runner.py | 147 ++++++------ .../advanced_chat/generate_task_pipeline.py | 15 +- api/core/app/apps/workflow/app_runner.py | 58 ++--- .../apps/workflow/generate_task_pipeline.py | 15 +- api/core/app/apps/workflow_app_runner.py | 124 ++++++++++- .../task_pipeline/workflow_cycle_manage.py | 16 +- api/core/workflow/errors.py | 1 - .../workflow/graph_engine/graph_engine.py | 134 +++++------ api/core/workflow/nodes/answer/answer_node.py | 5 +- api/core/workflow/nodes/code/code_node.py | 3 +- api/core/workflow/nodes/end/end_node.py | 4 +- .../nodes/http_request/http_request_node.py | 4 +- .../workflow/nodes/if_else/if_else_node.py | 4 +- .../nodes/iteration/iteration_node.py | 1 - .../knowledge_retrieval_node.py | 4 +- api/core/workflow/nodes/llm/llm_node.py | 5 +- .../parameter_extractor_node.py | 3 +- .../question_classifier_node.py | 4 +- api/core/workflow/nodes/start/start_node.py | 5 +- .../template_transform_node.py | 3 +- .../variable_aggregator_node.py | 4 +- api/core/workflow/workflow_entry.py | 209 +++--------------- 22 files changed, 378 insertions(+), 390 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index be04ed29b0..1483a80edb 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -7,7 +7,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( @@ -15,7 +15,6 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( - AppQueueEvent, QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent, @@ -84,86 +83,84 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) - # if only single iteration run is requested if self.application_generate_entity.single_iteration_run: - node_id = self.application_generate_entity.single_iteration_run.node_id - user_inputs = self.application_generate_entity.single_iteration_run.inputs - - generator = WorkflowEntry.single_step_run_iteration( + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, - node_id=node_id, - user_id=self.application_generate_entity.user_id, - user_inputs=user_inputs, - callbacks=workflow_callbacks + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files + + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity + ): + return + + db.session.close() + + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + session.commit() + # Convert database entities to variables + conversation_variables = [item.to_variable() for item in conversation_variables] + + # Create a variable pool. + system_inputs = { + SystemVariable.QUERY: query, + SystemVariable.FILES: files, + SystemVariable.CONVERSATION_ID: self.conversation.id, + SystemVariable.USER_ID: user_id, + } + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, ) - for event in generator: - # TODO - self._handle_event(workflow_entry, event) - return - - inputs = self.application_generate_entity.inputs - query = self.application_generate_entity.query - files = self.application_generate_entity.files - - # moderation - if self.handle_input_moderation( - app_record=app_record, - app_generate_entity=self.application_generate_entity, - inputs=inputs, - query=query, - message_id=self.message.id - ): - return - - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=self.message, - query=query, - app_generate_entity=self.application_generate_entity - ): - return - - db.session.close() - - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - conversation_variables = [ - ConversationVariable.from_variable( - app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - session.commit() - # Convert database entities to variables - conversation_variables = [item.to_variable() for item in conversation_variables] - - # Create a variable pool. - system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: self.conversation.id, - SystemVariable.USER_ID: user_id, - } - - # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) # RUN WORKFLOW workflow_entry = WorkflowEntry( - workflow=workflow, + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=workflow.type, + graph=graph, + graph_config=workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 89d9a4deb9..8c7860f8b4 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -276,25 +276,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc event=event ) - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueNodeFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueIterationStartEvent): if not workflow_run: raise Exception('Workflow run not initialized.') diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 0175599938..6cc2a74bbc 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -71,41 +71,41 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): # if only single iteration run is requested if self.application_generate_entity.single_iteration_run: - node_id = self.application_generate_entity.single_iteration_run.node_id - user_inputs = self.application_generate_entity.single_iteration_run.inputs - - generator = WorkflowEntry.single_step_run_iteration( + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( workflow=workflow, - node_id=node_id, - user_id=self.application_generate_entity.user_id, - user_inputs=user_inputs, - callbacks=workflow_callbacks + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs + ) + else: + + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. + system_inputs = { + SystemVariable.FILES: files, + SystemVariable.USER_ID: user_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], ) - for event in generator: - # TODO - self._handle_event(workflow_entry, event) - return - - inputs = self.application_generate_entity.inputs - files = self.application_generate_entity.files - - # Create a variable pool. - system_inputs = { - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id, - } - - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) # RUN WORKFLOW workflow_entry = WorkflowEntry( - workflow=workflow, + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=workflow.type, + graph=graph, + graph_config=workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 6955844af5..08a335ec36 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -249,25 +249,34 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa event=event ) - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueNodeFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + + if response: + yield response elif isinstance(event, QueueIterationStartEvent): if not workflow_run: raise Exception('Workflow run not initialized.') diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 212fde82f8..42a2f60582 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Mapping, Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner @@ -18,6 +18,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, @@ -36,6 +38,10 @@ from core.workflow.graph_engine.entities.event import ( ParallelBranchRunStartedEvent, ParallelBranchRunSucceededEvent, ) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.node_mapping import node_classes from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App @@ -46,6 +52,122 @@ class WorkflowBasedAppRunner(AppRunner): def __init__(self, queue_manager: AppQueueManager): self.queue_manager = queue_manager + def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: + """ + Init graph + """ + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + # init graph + graph = Graph.init( + graph_config=graph_config + ) + + if not graph: + raise ValueError('graph not found in workflow') + + return graph + + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + # fetch workflow graph + graph_config = workflow.graph_dict + if not graph_config: + raise ValueError('workflow graph not found') + + graph_config = cast(dict[str, Any], graph_config) + + if 'nodes' not in graph_config or 'edges' not in graph_config: + raise ValueError('nodes or edges not found in workflow graph') + + if not isinstance(graph_config.get('nodes'), list): + raise ValueError('nodes in workflow graph must be a list') + + if not isinstance(graph_config.get('edges'), list): + raise ValueError('edges in workflow graph must be a list') + + # filter nodes only in iteration + node_configs = [ + node for node in graph_config.get('nodes', []) + if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id + ] + + graph_config['nodes'] = node_configs + + node_ids = [node.get('id') for node in node_configs] + + # filter edges only in iteration + edge_configs = [ + edge for edge in graph_config.get('edges', []) + if (edge.get('source') is None or edge.get('source') in node_ids) + and (edge.get('target') is None or edge.get('target') in node_ids) + ] + + graph_config['edges'] = edge_configs + + # init graph + graph = Graph.init( + graph_config=graph_config, + root_node_id=node_id + ) + + if not graph: + raise ValueError('graph not found in workflow') + + # fetch node config from node id + iteration_node_config = None + for node in node_configs: + if node.get('id') == node_id: + iteration_node_config = node + break + + if not iteration_node_config: + raise ValueError('iteration node id not found in workflow graph') + + # Get node class + node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) + node_cls = node_classes.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, + config=iteration_node_config + ) + except NotImplementedError: + variable_mapping = {} + + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get('data', {})) + ) + + return graph, variable_pool + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: """ Handle event diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ad3aa85b7f..15a9833a66 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -231,7 +231,6 @@ class WorkflowCycleManage: outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp() workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None @@ -239,6 +238,7 @@ class WorkflowCycleManage: json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -259,11 +259,11 @@ class WorkflowCycleManage: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = event.error - workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp() workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds() db.session.commit() db.session.refresh(workflow_node_execution) @@ -344,7 +344,7 @@ class WorkflowCycleManage: def _workflow_node_start_to_stream_response( self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution - ) -> NodeStartStreamResponse: + ) -> Optional[NodeStartStreamResponse]: """ Workflow node start to stream response. :param event: queue node started event @@ -352,6 +352,9 @@ class WorkflowCycleManage: :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -380,13 +383,16 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, task_id: str, workflow_node_execution: WorkflowNodeExecution - ) -> NodeFinishStreamResponse: + ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + return None + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -483,7 +489,7 @@ class WorkflowCycleManage: inputs=event.inputs or {}, status=WorkflowNodeExecutionStatus.SUCCEEDED, error=None, - elapsed_time=time.perf_counter() - event.start_at.timestamp(), + elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 80b2501da2..07cbcd981e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,4 +1,3 @@ -from core.workflow.entities.node_entities import NodeType from core.workflow.nodes.base_node import BaseNode diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 5a0d8a10e2..b9e3e78e5c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -39,7 +39,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor -from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_classes from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -400,81 +400,85 @@ class GraphEngine: # run node generator = node_instance.run() for item in generator: - if isinstance(item, RunCompletedEvent): - run_result = item.run_result - route_node_state.set_finished(run_result=run_result) + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id - if run_result.status == WorkflowNodeExecutionStatus.FAILED: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or 'Unknown error.', - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id - ) - elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - # plus state total_tokens - self.graph_runtime_state.total_tokens += int( - run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or 'Unknown error.', + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id ) - - if run_result.llm_usage: - # use the latest usage - self.graph_runtime_state.llm_usage += run_result.llm_usage - - # append node output variables to variable pool - if run_result.outputs: - for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node_instance.node_id, - variable_key_list=[variable_key], - variable_value=variable_value + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] ) - yield NodeRunSucceededEvent( + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id + ) + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, ) - - break - elif isinstance(item, RunStreamChunkEvent): - yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - ) - elif isinstance(item, RunRetrieverResourceEvent): - yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - retriever_resources=item.retriever_resources, - context=item.context, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - ) - elif isinstance(item, BaseIterationEvent): - # add parallel info to iteration event - item.parallel_id = parallel_id - item.parallel_start_node_id = parallel_start_node_id except GenerateTaskStoppedException: # trigger node run failed event route_node_state.status = RouteNodeState.Status.FAILED diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index d2311b04e9..8cf01727ec 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,8 +1,7 @@ -from typing import Any, Mapping, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 7c066ad083..f0d8ffbca4 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional, Sequence, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 8299f4d9f2..552914b308 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,6 +1,6 @@ -from typing import Any, Mapping, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.entities import EndNodeData diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 6a94a6bd32..c69394a891 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,12 +1,12 @@ import logging +from collections.abc import Mapping, Sequence from mimetypes import guess_extension from os import path -from typing import Any, Mapping, Sequence, cast +from typing import Any, cast from core.app.segments import parser from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.http_request.entities import ( diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index feb0175a74..ca87eecd0d 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,6 +1,6 @@ -from typing import Any, Mapping, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.if_else.entities import IfElseNodeData diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index f7904aa836..7f6990604a 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -5,7 +5,6 @@ from typing import Any, cast from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1e9ff9ff79..b2991b624b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Mapping, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast from sqlalchemy import func @@ -13,7 +14,6 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrival_methods import RetrievalMethod -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 5fdf2456df..b2f4f3ad4b 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,7 +1,7 @@ import json -from collections.abc import Generator +from collections.abc import Generator, Mapping, Sequence from copy import deepcopy -from typing import Any, Mapping, Optional, Sequence, cast +from typing import Any, Optional, cast from pydantic import BaseModel @@ -24,7 +24,6 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import InNodeEvent diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index f4ff251ead..2e65705f10 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,6 +1,7 @@ import json import uuid -from typing import Any, Mapping, Optional, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 97996872d9..777ff468f0 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,7 @@ import json import logging -from typing import Any, Mapping, Optional, Sequence, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -14,7 +15,6 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 826c3526e6..10131bd6a5 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,6 +1,7 @@ -from typing import Any, Mapping, Sequence -from core.workflow.entities.base_node_data_entities import BaseNodeData +from collections.abc import Mapping, Sequence +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 4a19792c64..b14a394a0a 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,5 +1,6 @@ import os -from typing import Any, Mapping, Optional, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 186bbce2af..6944d9e82d 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,6 +1,6 @@ -from typing import Any, Mapping, Sequence, cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 8ab5d27eb2..d681174716 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, Type, cast +from typing import Any, Optional, cast from configs import dify_config from core.app.app_config.entities import FileExtraConfig @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom +from core.workflow.entities.node_entities import NodeType, UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent @@ -20,8 +20,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.event import RunCompletedEvent, RunEvent -from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.event import RunEvent from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.node_mapping import node_classes from models.workflow import ( @@ -35,7 +34,12 @@ logger = logging.getLogger(__name__) class WorkflowEntry: def __init__( self, - workflow: Workflow, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, @@ -43,46 +47,29 @@ class WorkflowEntry: variable_pool: VariablePool ) -> None: """ - :param workflow: Workflow instance + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph :param user_id: user id :param user_from: user from - :param invoke_from: invoke from service-api, web-app, debugger, explore + :param invoke_from: invoke from :param call_depth: call depth :param variable_pool: variable pool - :param single_step_run_iteration_id: single step run iteration id """ - # fetch workflow graph - graph_config = workflow.graph_dict - if not graph_config: - raise ValueError('workflow graph not found') - - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - # init graph - graph = Graph.init( - graph_config=graph_config - ) - - if not graph: - raise ValueError('graph not found in workflow') - # init workflow run state self.graph_engine = GraphEngine( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_type=WorkflowType.value_of(workflow.type), - workflow_id=workflow.id, + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, user_id=user_id, user_from=user_from, invoke_from=invoke_from, @@ -133,152 +120,6 @@ class WorkflowEntry: ) return - @classmethod - def single_step_run_iteration( - cls, - workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[WorkflowCallback], - ) -> Generator[GraphEngineEvent, None, None]: - """ - Single step run workflow node iteration - :param workflow: Workflow instance - :param node_id: node id - :param user_id: user id - :param user_inputs: user inputs - :return: - """ - # fetch workflow graph - graph_config = workflow.graph_dict - if not graph_config: - raise ValueError('workflow graph not found') - - graph_config = cast(dict[str, Any], graph_config) - - if 'nodes' not in graph_config or 'edges' not in graph_config: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph_config.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph_config.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - - # filter nodes only in iteration - node_configs = [ - node for node in graph_config.get('nodes', []) - if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id - ] - - graph_config['nodes'] = node_configs - - node_ids = [node.get('id') for node in node_configs] - - # filter edges only in iteration - edge_configs = [ - edge for edge in graph_config.get('edges', []) - if (edge.get('source') is None or edge.get('source') in node_ids) - and (edge.get('target') is None or edge.get('target') in node_ids) - ] - - graph_config['edges'] = edge_configs - - # init graph - graph = Graph.init( - graph_config=graph_config, - root_node_id=node_id - ) - - if not graph: - raise ValueError('graph not found in workflow') - - # fetch node config from node id - iteration_node_config = None - for node in node_configs: - if node.get('id') == node_id: - iteration_node_config = node - break - - if not iteration_node_config: - raise ValueError('iteration node id not found in workflow graph') - - # Get node class - node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) - node_cls = cast(type[BaseNode], node_cls) - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) - - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, - config=iteration_node_config - ) - except NotImplementedError: - variable_mapping = {} - - cls._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_type=node_type, - node_data=IterationNodeData(**iteration_node_config.get('data', {})) - ) - - # init workflow run state - graph_engine = GraphEngine( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_type=WorkflowType.value_of(workflow.type), - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=1, - graph=graph, - graph_config=graph_config, - variable_pool=variable_pool, - max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME - ) - - try: - # run workflow - generator = graph_engine.run() - for event in generator: - if callbacks: - for callback in callbacks: - callback.on_event( - graph=graph_engine.graph, - graph_init_params=graph_engine.init_params, - graph_runtime_state=graph_engine.graph_runtime_state, - event=event - ) - yield event - except GenerateTaskStoppedException: - pass - except Exception as e: - logger.exception("Unknown Error when workflow entry running") - if callbacks: - for callback in callbacks: - callback.on_event( - graph=graph_engine.graph, - graph_init_params=graph_engine.init_params, - graph_runtime_state=graph_engine.graph_runtime_state, - event=GraphRunFailedEvent( - error=str(e) - ) - ) - return - @classmethod def single_step_run( cls, @@ -366,7 +207,7 @@ class WorkflowEntry: except NotImplementedError: variable_mapping = {} - cls._mapping_user_inputs_to_variable_pool( + cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, variable_pool=variable_pool, @@ -413,7 +254,7 @@ class WorkflowEntry: return new_value @classmethod - def _mapping_user_inputs_to_variable_pool( + def mapping_user_inputs_to_variable_pool( cls, variable_mapping: Mapping[str, Sequence[str]], user_inputs: dict, @@ -428,11 +269,11 @@ class WorkflowEntry: if len(node_variable_list) < 1: raise ValueError(f'Invalid node variable {node_variable}') - node_variable_key = node_variable_list[1:] + node_variable_key = '.'.join(node_variable_list[1:]) if ( node_variable_key not in user_inputs - or node_variable not in user_inputs + and node_variable not in user_inputs ) and not variable_pool.get(variable_selector): raise ValueError(f'Variable key {node_variable} not found in user inputs.')