diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index d1ee8bf166..58e6248d12 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Any, Optional, Union from flask import Flask, current_app from pydantic import ValidationError @@ -39,7 +39,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -67,7 +67,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # get conversation conversation = None if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation = self._get_conversation_by_user(app_model, args.get('conversation_id', ''), user) # parse files files = args['files'] if args.get('files') else [] @@ -126,9 +126,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation = None, + conversation: Optional[Conversation] = None, stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + -> dict[str, Any] | Generator[str, Any, None]: is_first_conversation = False if not conversation: is_first_conversation = True @@ -157,7 +157,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # new thread worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, 'conversation_id': conversation.id, @@ -209,13 +209,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message = self._get_message(message_id) # chatbot app - runner = AdvancedChatAppRunner() - runner.run( + runner = AdvancedChatAppRunner( application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, message=message ) + + runner.run() except GenerateTaskStoppedException: pass except InvokeAuthorizationError: @@ -227,7 +228,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true': logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 42fdb750ab..14fc5d993e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,6 +1,5 @@ import logging import os -import time from collections.abc import Mapping from typing import Any, Optional, cast @@ -12,10 +11,45 @@ from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, ) -from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueAnnotationReplyEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) from core.moderation.base import ModerationException from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable, UserFrom +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, Conversation, EndUser, Message @@ -29,19 +63,30 @@ class AdvancedChatAppRunner(AppRunner): AdvancedChat Application Runner """ - def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, + def __init__( + self, + application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, - message: Message) -> None: + message: Message + ) -> None: """ - Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :param conversation: conversation :param message: message + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.conversation = conversation + self.message = message + + def run(self) -> None: + """ + Run application :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) app_record = db.session.query(App).filter(App.id == app_config.app_id).first() @@ -52,36 +97,34 @@ class AdvancedChatAppRunner(AppRunner): if not workflow: raise ValueError("Workflow not initialized") - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id # moderation if self.handle_input_moderation( - queue_manager=queue_manager, app_record=app_record, - app_generate_entity=application_generate_entity, + app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=self.message.id ): return # annotation reply if self.handle_annotation_reply( app_record=app_record, - message=message, + message=self.message, query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity + app_generate_entity=self.application_generate_entity ): return @@ -92,25 +135,189 @@ class AdvancedChatAppRunner(AppRunner): workflow_callbacks.append(WorkflowLoggingCallback()) # RUN WORKFLOW - workflow_entry = WorkflowEntry() - workflow_entry.run( + workflow_entry = WorkflowEntry( workflow=workflow, - user_id=application_generate_entity.user_id, + user_id=self.application_generate_entity.user_id, user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, + invoke_from=self.application_generate_entity.invoke_from, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, + SystemVariable.CONVERSATION_ID: self.conversation.id, SystemVariable.USER_ID: user_id }, - call_depth=application_generate_entity.call_depth + call_depth=self.application_generate_entity.call_depth ) + generator = workflow_entry.run( + callbacks=workflow_callbacks, + ) + + for event in generator: + self._handle_event(workflow_entry, event) + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent( + graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state + ) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event( + QueueWorkflowSucceededEvent(outputs=event.outputs) + ) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event( + QueueWorkflowFailedEvent(error=event.error) + ) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + predecessor_node_id=event.predecessor_node_id + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result else {}, + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result + and event.route_node_state.node_run_result.error + else "Unknown error" + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id + ) + ) + elif isinstance(event, ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + error=event.error + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None + ) + ) + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ Get workflow @@ -126,7 +333,7 @@ class AdvancedChatAppRunner(AppRunner): return workflow def handle_input_moderation( - self, queue_manager: AppQueueManager, + self, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, inputs: Mapping[str, Any], @@ -135,7 +342,6 @@ class AdvancedChatAppRunner(AppRunner): ) -> bool: """ Handle input moderation - :param queue_manager: application queue manager :param app_record: app record :param app_generate_entity: application generate entity :param inputs: inputs @@ -154,10 +360,8 @@ class AdvancedChatAppRunner(AppRunner): message_id=message_id, ) except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=str(e), - stream=app_generate_entity.stream, stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION ) return True @@ -167,14 +371,12 @@ class AdvancedChatAppRunner(AppRunner): def handle_annotation_reply(self, app_record: App, message: Message, query: str, - queue_manager: AppQueueManager, app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: """ Handle annotation reply :param app_record: app record :param message: message :param query: query - :param queue_manager: application queue manager :param app_generate_entity: application generate entity """ # annotation reply @@ -187,50 +389,38 @@ class AdvancedChatAppRunner(AppRunner): ) if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + self._publish_event( + QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id) ) - self._stream_output( - queue_manager=queue_manager, + self._complete_with_stream_output( text=annotation_reply.content, - stream=app_generate_entity.stream, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _stream_output(self, queue_manager: AppQueueManager, - text: str, - stream: bool, - stopped_by: QueueStopEvent.StopBy) -> None: + def _complete_with_stream_output(self, + text: str, + stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output - :param queue_manager: application queue manager :param text: text - :param stream: stream :return: """ - if stream: - index = 0 - for token in text: - queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.APPLICATION_MANAGER - ) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER + self._publish_event( + QueueTextChunkEvent( + text=text ) + ) - queue_manager.publish( - QueueStopEvent(stopped_by=stopped_by), + self._publish_event( + QueueStopEvent(stopped_by=stopped_by) + ) + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish( + event, PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b4ff94f59d..63c1d1de7b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -30,7 +30,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, @@ -38,14 +37,15 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, + WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable +from core.workflow.entities.node_entities import SystemVariable +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from events.message_event import message_was_created from extensions.ext_database import db from models.account import Account @@ -62,15 +62,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: AdvancedChatTaskState + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariable, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, workflow: Workflow, queue_manager: AppQueueManager, conversation: Conversation, @@ -105,12 +105,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc SystemVariable.USER_ID: user_id } - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) + self._task_state = WorkflowTaskState() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None def process(self): @@ -131,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) + if self._stream: return self._to_stream_response(generator) else: @@ -190,17 +187,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ Generator[StreamResponse, None, None]: - publisher = None + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -211,9 +209,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.checkAndGetAudio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -231,26 +229,37 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ + # init fake graph runtime state + graph_runtime_state = None + workflow_run = None + for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) + if tts_publisher: + tts_publisher.publish(message=message) + event = message.event - if isinstance(event, QueueErrorEvent): + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + # init workflow run + workflow_run = self._handle_workflow_run_start() + + self._refetch_message() self._message.workflow_run_id = workflow_run.id db.session.commit() @@ -262,76 +271,158 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception('Workflow run not initialized.') + + workflow_node_execution = self._handle_node_execution_start( + workflow_run=workflow_run, + event=event + ) yield self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) yield self._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) + elif isinstance(event, QueueNodeFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager + yield self._workflow_node_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=json.dumps(event.outputs) if event.outputs else None, + conversation_id=self._conversation.id, + trace_manager=trace_manager, ) - if workflow_run: - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() + self._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), + PublishFrom.TASK_PIPELINE + ) + elif isinstance(event, QueueWorkflowFailedEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) + + err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) + yield self._error_to_stream_response(self._handle_error(err_event, self._message)) + break + elif isinstance(event, QueueStopEvent): + if not workflow_run: + raise Exception('Workflow run not initialized.') + + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error='Workflow stopped.', + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run + ) # Save message - self._save_message() + self._save_message(graph_runtime_state=graph_runtime_state) yield self._message_end_to_stream_response() + break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) + + self._refetch_message() + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.close() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) + + self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ + if self._task_state.metadata else None + + db.session.commit() + db.session.close() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -345,31 +436,44 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._task_state.answer += delta_text yield self._message_to_stream_response(delta_text, self._message.id) elif isinstance(event, QueueMessageReplaceEvent): + # published by moderation yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + elif isinstance(event, QueueAdvancedChatMessageEndEvent): + if not graph_runtime_state: + raise Exception('Graph runtime state not initialized.') + + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message(graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() else: continue - if publisher: - publisher.publish(None) + + if tts_publisher: + tts_publisher.publish(None) + if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self) -> None: + def _save_message(self, graph_runtime_state: GraphRuntimeState) -> None: """ Save message. :return: """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ if self._task_state.metadata else None - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) - + if graph_runtime_state.llm_usage: + usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -404,26 +508,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc **extras ) - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - def _handle_output_moderation_chunk(self, text: str) -> bool: """ Handle output moderation chunk. @@ -449,3 +533,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._output_moderation_handler.append_new_token(text) return False + + def _refetch_message(self) -> None: + """ + Refetch message. + :return: + """ + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if message: + self._message = message diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 1165314a7f..a196d36be5 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert(cls, response: Union[ AppBlockingResponse, Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom): + ], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]: if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) diff --git a/api/core/app/apps/chatflow/__init__.py b/api/core/app/apps/chatflow/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/app/apps/chatflow/app_config_manager.py b/api/core/app/apps/chatflow/app_config_manager.py deleted file mode 100644 index c3d0e8ba03..0000000000 --- a/api/core/app/apps/chatflow/app_config_manager.py +++ /dev/null @@ -1,101 +0,0 @@ - -from core.app.app_config.base_app_config_manager import BaseAppConfigManager -from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager -from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager -from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager -from core.app.app_config.features.suggested_questions_after_answer.manager import ( - SuggestedQuestionsAfterAnswerConfigManager, -) -from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager -from models.model import App, AppMode -from models.workflow import Workflow - - -class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): - """ - Advanced Chatbot App Config Entity. - """ - pass - - -class AdvancedChatAppConfigManager(BaseAppConfigManager): - @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: - features_dict = workflow.features_dict - - app_mode = AppMode.value_of(app_model.mode) - app_config = AdvancedChatAppConfig( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - app_mode=app_mode, - workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) - ) - - return app_config - - @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: - """ - Validate for advanced chat app model config - - :param tenant_id: tenant id - :param config: app model config args - :param only_structure_validate: if True, only structure validation will be performed - """ - related_config_keys = [] - - # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False - ) - related_config_keys.extend(current_related_config_keys) - - # opening_statement - config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # suggested_questions_after_answer - config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) - related_config_keys.extend(current_related_config_keys) - - # speech_to_text - config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # text_to_speech - config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # return retriever resource - config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config) - related_config_keys.extend(current_related_config_keys) - - # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate - ) - related_config_keys.extend(current_related_config_keys) - - related_config_keys = list(set(related_config_keys)) - - # Filter out extra parameters - filtered_config = {key: config.get(key) for key in related_config_keys} - - return filtered_config - diff --git a/api/core/app/apps/chatflow/app_generator.py b/api/core/app/apps/chatflow/app_generator.py deleted file mode 100644 index ab21f6dbb2..0000000000 --- a/api/core/app/apps/chatflow/app_generator.py +++ /dev/null @@ -1,189 +0,0 @@ -import logging -import os -import uuid -from collections.abc import Generator -from typing import Union - -from pydantic import ValidationError - -import contexts -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.app.apps.chatflow.app_runner import AdvancedChatAppRunner -from core.app.apps.message_based_app_generator import MessageBasedAppGenerator -from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError -from core.ops.ops_trace_manager import TraceQueueManager -from extensions.ext_database import db -from models.account import Account -from models.model import App, Conversation, EndUser -from models.workflow import Workflow - -logger = logging.getLogger(__name__) - - -class AdvancedChatAppGenerator(MessageBasedAppGenerator): - def generate( - self, app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: - """ - Generate App response. - - :param app_model: App - :param workflow: Workflow - :param user: account or end user - :param args: request args - :param invoke_from: invoke from source - :param stream: is stream - """ - if not args.get('query'): - raise ValueError('query is required') - - query = args['query'] - if not isinstance(query, str): - raise ValueError('query must be a string') - - query = query.replace('\x00', '') - inputs = args['inputs'] - - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } - - # get conversation - conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) - - # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) - else: - file_objs = [] - - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) - - # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) - - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True - - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), - query=query, - files=file_objs, - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager - ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) - - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - conversation=conversation, - stream=stream - ) - - def _generate(self, app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: - is_first_conversation = False - if not conversation: - is_first_conversation = True - - # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) - - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id - ) - - try: - # chatbot app - runner = AdvancedChatAppRunner() - response = runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message - ) - except GenerateTaskStoppedException: - pass - except InvokeAuthorizationError: - raise - except ValidationError as e: - logger.exception("Validation Error when generating") - raise e - except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() - else: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': - logger.exception(e) - raise e - except InvokeError as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': - logger.exception("Error when generating") - raise e - except Exception as e: - logger.exception("Unknown Error when generating") - raise e - finally: - db.session.close() - - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) diff --git a/api/core/app/apps/chatflow/app_runner.py b/api/core/app/apps/chatflow/app_runner.py deleted file mode 100644 index 3b6930cfc6..0000000000 --- a/api/core/app/apps/chatflow/app_runner.py +++ /dev/null @@ -1,422 +0,0 @@ -import logging -import os -import time -from collections.abc import Generator, Mapping -from typing import Any, Optional, cast - -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueAnnotationReplyEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueParallelBranchRunFailedEvent, - QueueParallelBranchRunStartedEvent, - QueueRetrieverResourcesEvent, - QueueStopEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.moderation.base import ModerationException -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import SystemVariable, UserFrom -from core.workflow.graph_engine.entities.event import ( - GraphRunFailedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - IterationRunFailedEvent, - IterationRunNextEvent, - IterationRunStartedEvent, - IterationRunSucceededEvent, - NodeRunFailedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - ParallelBranchRunFailedEvent, - ParallelBranchRunStartedEvent, - ParallelBranchRunSucceededEvent, -) -from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow - -logger = logging.getLogger(__name__) - - -class AdvancedChatAppRunner(AppRunner): - """ - AdvancedChat Application Runner - """ - - def run(self, application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> Generator[AppQueueEvent, None, None]: - """ - Run application - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param conversation: conversation - :param message: message - :return: - """ - app_config = application_generate_entity.app_config - app_config = cast(AdvancedChatAppConfig, app_config) - - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() - if not app_record: - raise ValueError("App not found") - - workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # moderation - if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id - ): - return - - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity - ): - return - - db.session.close() - - workflow_callbacks: list[WorkflowCallback] = [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - workflow_callbacks.append(WorkflowLoggingCallback()) - - # RUN WORKFLOW - workflow_entry = WorkflowEntry( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - user_inputs=inputs, - system_inputs={ - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id - }, - call_depth=application_generate_entity.call_depth - ) - generator = workflow_entry.run( - callbacks=workflow_callbacks, - ) - - for event in generator: - if isinstance(event, GraphRunStartedEvent): - queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, GraphRunSucceededEvent): - queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, GraphRunFailedEvent): - queue_manager.publish( - QueueWorkflowFailedEvent(error=event.error), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, NodeRunStartedEvent): - queue_manager.publish( - QueueNodeStartedEvent( - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - predecessor_node_id=event.predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, NodeRunSucceededEvent): - queue_manager.publish( - QueueNodeSucceededEvent( - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - execution_metadata=event.route_node_state.node_run_result.metadata - if event.route_node_state.node_run_result else {}, - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, NodeRunFailedEvent): - queue_manager.publish( - QueueNodeFailedEvent( - node_id=event.node_id, - node_type=event.node_type, - node_data=event.node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - inputs=event.route_node_state.node_run_result.inputs - if event.route_node_state.node_run_result else {}, - process_data=event.route_node_state.node_run_result.process_data - if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs - if event.route_node_state.node_run_result else {}, - error=event.route_node_state.node_run_result.error - if event.route_node_state.node_run_result - and event.route_node_state.node_run_result.error - else "Unknown error" - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, NodeRunStreamChunkEvent): - queue_manager.publish( - QueueTextChunkEvent( - text=event.chunk_content - ), PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, NodeRunRetrieverResourceEvent): - queue_manager.publish( - QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources - ), PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, ParallelBranchRunStartedEvent): - queue_manager.publish( - QueueParallelBranchRunStartedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, ParallelBranchRunSucceededEvent): - queue_manager.publish( - QueueParallelBranchRunStartedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, ParallelBranchRunFailedEvent): - queue_manager.publish( - QueueParallelBranchRunFailedEvent( - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - error=event.error - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, IterationRunStartedEvent): - queue_manager.publish( - QueueIterationStartEvent( - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - node_data=event.iteration_node_data, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id, - metadata=event.metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, IterationRunNextEvent): - queue_manager.publish( - QueueIterationNextEvent( - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - index=event.index, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - output=event.pre_iteration_output, - ), - PublishFrom.APPLICATION_MANAGER - ) - elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): - queue_manager.publish( - QueueIterationCompletedEvent( - node_id=event.iteration_node_id, - node_type=event.iteration_node_type, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error if isinstance(event, IterationRunFailedEvent) else None - ), - PublishFrom.APPLICATION_MANAGER - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == workflow_id - ).first() - - # return workflow - return workflow - - def handle_input_moderation( - self, queue_manager: AppQueueManager, - app_record: App, - app_generate_entity: AdvancedChatAppGenerateEntity, - inputs: Mapping[str, Any], - query: str, - message_id: str - ) -> bool: - """ - Handle input moderation - :param queue_manager: application queue manager - :param app_record: app record - :param app_generate_entity: application generate entity - :param inputs: inputs - :param query: query - :param message_id: message id - :return: - """ - try: - # process sensitive_word_avoidance - _, inputs, query = self.moderation_for_inputs( - app_id=app_record.id, - tenant_id=app_generate_entity.app_config.tenant_id, - app_generate_entity=app_generate_entity, - inputs=inputs, - query=query, - message_id=message_id, - ) - except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, - text=str(e), - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION - ) - return True - - return False - - def handle_annotation_reply(self, app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity) -> bool: - """ - Handle annotation reply - :param app_record: app record - :param message: message - :param query: query - :param queue_manager: application queue manager - :param app_generate_entity: application generate entity - """ - # annotation reply - annotation_reply = self.query_app_annotations_to_reply( - app_record=app_record, - message=message, - query=query, - user_id=app_generate_entity.user_id, - invoke_from=app_generate_entity.invoke_from - ) - - if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER - ) - - self._stream_output( - queue_manager=queue_manager, - text=annotation_reply.content, - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY - ) - return True - - return False - - def _stream_output(self, queue_manager: AppQueueManager, - text: str, - stream: bool, - stopped_by: QueueStopEvent.StopBy) -> None: - """ - Direct output - :param queue_manager: application queue manager - :param text: text - :param stream: stream - :return: - """ - if stream: - index = 0 - for token in text: - queue_manager.publish( - QueueTextChunkEvent( - text=token - ), PublishFrom.APPLICATION_MANAGER - ) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) - - queue_manager.publish( - QueueStopEvent(stopped_by=stopped_by), - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/chatflow/generate_task_pipeline.py b/api/core/app/apps/chatflow/generate_task_pipeline.py deleted file mode 100644 index 050595c786..0000000000 --- a/api/core/app/apps/chatflow/generate_task_pipeline.py +++ /dev/null @@ -1,450 +0,0 @@ -import json -import logging -import time -from collections.abc import Generator -from typing import Any, Optional, Union - -from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME -from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, -) -from core.app.entities.queue_entities import ( - QueueAdvancedChatMessageEndEvent, - QueueAnnotationReplyEvent, - QueueErrorEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueMessageReplaceEvent, - QueueNodeFailedEvent, - QueueNodeSucceededEvent, - QueuePingEvent, - QueueRetrieverResourcesEvent, - QueueStopEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, -) -from core.app.entities.task_entities import ( - AdvancedChatTaskState, - ChatbotAppBlockingResponse, - ChatbotAppStreamResponse, - ErrorStreamResponse, - MessageAudioEndStreamResponse, - MessageAudioStreamResponse, - MessageEndStreamResponse, - StreamResponse, -) -from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage -from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder -from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable -from core.workflow.graph_engine.entities.event import GraphRunStartedEvent, NodeRunStartedEvent -from events.message_event import message_was_created -from extensions.ext_database import db -from models.account import Account -from models.model import Conversation, EndUser, Message -from models.workflow import ( - Workflow, - WorkflowRunStatus, -) - -logger = logging.getLogger(__name__) - - -class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): - """ - AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. - """ - _task_state: AdvancedChatTaskState - _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] - _workflow_system_variables: dict[SystemVariable, Any] - _iteration_nested_relations: dict[str, list[str]] - - def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool - ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) - - if isinstance(self._user, EndUser): - user_id = self._user.session_id - else: - user_id = self._user.id - - self._workflow = workflow - self._conversation = conversation - self._message = message - self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id - } - - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) - - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() - self._conversation_name_generate_thread = None - - def process(self): - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - - # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query - ) - - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) - if self._stream: - return self._to_stream_response(generator) - else: - return self._to_blocking_response(generator) - - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse: - """ - Process blocking response. - :return: - """ - for stream_response in generator: - if isinstance(stream_response, ErrorStreamResponse): - raise stream_response.err - elif isinstance(stream_response, MessageEndStreamResponse): - extras = {} - if stream_response.metadata: - extras['metadata'] = stream_response.metadata - - return ChatbotAppBlockingResponse( - task_id=stream_response.task_id, - data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, - answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), - **extras - ) - ) - else: - continue - - raise Exception('Queue listening stopped unexpectedly.') - - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: - """ - To stream response. - :return: - """ - for stream_response in generator: - yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response - ) - - def _listenAudioMsg(self, publisher, task_id: str): - if not publisher: - return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() - if audio_msg and audio_msg.status != "finish": - return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) - return None - - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - - publisher = None - task_id = self._application_generate_entity.task_id - tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict - - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): - while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) - if audio_response: - yield audio_response - else: - break - yield response - - start_listener_time = time.time() - # timeout - while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: - try: - if not publisher: - break - audio_trunk = publisher.checkAndGetAudio() - if audio_trunk is None: - # release cpu - # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) - time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) - continue - if audio_trunk.status == "finish": - break - else: - start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) - except Exception as e: - logger.error(e) - break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - - def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None - ) -> Generator[StreamResponse, None, None]: - """ - Process stream response. - :return: - """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event - - if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) - yield self._error_to_stream_response(err) - break - elif isinstance(event, GraphRunStartedEvent): - workflow_run = self._handle_workflow_start() - - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._message.workflow_run_id = workflow_run.id - - db.session.commit() - db.session.refresh(self._message) - db.session.close() - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) - elif isinstance(event, NodeRunStartedEvent): - workflow_node_execution = self._handle_node_start(event) - - yield self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution - ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - - yield self._workflow_node_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution - ) - - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager - ) - if workflow_run: - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) - - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() - - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) - - # Save message - self._save_message() - - yield self._message_end_to_stream_response() - elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) - elif isinstance(event, QueueAnnotationReplyEvent): - self._handle_annotation_reply(event) - elif isinstance(event, QueueTextChunkEvent): - delta_text = event.text - if delta_text is None: - continue - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) - if should_direct_answer: - continue - - self._task_state.answer += delta_text - yield self._message_to_stream_response(delta_text, self._message.id) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() - else: - continue - if publisher: - publisher.publish(None) - if self._conversation_name_generate_thread: - self._conversation_name_generate_thread.join() - - def _save_message(self) -> None: - """ - Save message. - :return: - """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None - - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) - - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - - db.session.commit() - - message_was_created.send( - self._message, - application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras - ) - - def _message_end_to_stream_response(self) -> MessageEndStreamResponse: - """ - Message end to stream response. - :return: - """ - extras = {} - if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata - - return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras - ) - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - - def _handle_output_moderation_chunk(self, text: str) -> bool: - """ - Handle output moderation chunk. - :param text: text - :return: True if output moderation should direct output, otherwise False - """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output - self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE - ) - - self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE - ) - return True - else: - self._output_moderation_handler.append_new_token(text) - - return False diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 6b22d01340..618a91a999 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,7 +4,6 @@ from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig -from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback from core.app.entities.app_invoke_entities import ( InvokeFrom, @@ -12,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.entities.node_entities import SystemVariable, UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App, EndUser from models.workflow import Workflow @@ -57,17 +56,14 @@ class WorkflowAppRunner: db.session.close() - workflow_callbacks: list[WorkflowCallback] = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks: list[WorkflowCallback] = [] if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): workflow_callbacks.append(WorkflowLoggingCallback()) # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run( + workflow_entry = WorkflowEntry() + workflow_entry.run( workflow=workflow, user_id=application_generate_entity.user_id, user_from=UserFrom.ACCOUNT @@ -100,13 +96,10 @@ class WorkflowAppRunner: if not workflow: raise ValueError("Workflow not initialized") - workflow_callbacks = [WorkflowEventTriggerCallback( - queue_manager=queue_manager, - workflow=workflow - )] + workflow_callbacks = [] - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( + workflow_entry = WorkflowEntry() + workflow_entry.single_step_run_iteration_workflow_node( workflow=workflow, node_id=node_id, user_id=user_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2b4362150f..1b3379d39a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -36,14 +36,12 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowStreamGenerateNodes, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType, SystemVariable -from core.workflow.nodes.end.end_node import EndNode +from core.workflow.entities.node_entities import SystemVariable from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -51,7 +49,6 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, ) @@ -67,7 +64,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity _workflow_system_variables: dict[SystemVariable, Any] - _iteration_nested_relations: dict[str, list[str]] def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, workflow: Workflow, @@ -95,11 +91,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa SystemVariable.USER_ID: user_id } - self._task_state = WorkflowTaskState( - iteration_nested_node_ids=[] - ) - self._stream_generate_nodes = self._get_stream_generate_nodes() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) + self._task_state = WorkflowTaskState() def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -128,8 +120,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() + workflow_run = self._task_state.workflow_run + if not workflow_run: + raise Exception('Workflow run not found.') response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -161,8 +154,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa :return: """ for stream_response in generator: + if not self._task_state.workflow_run: + raise Exception('Workflow run not found.') + yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run_id, + workflow_run_id=self._task_state.workflow_run.id, stream_response=stream_response ) @@ -234,20 +230,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) - - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: - self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] - - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() + workflow_node_execution = self._handle_execution_node_start(event) yield self._workflow_node_start_to_stream_response( event=event, @@ -268,13 +257,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa error=f'Child node failed: {event.error}' ) elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) self._handle_iteration_operation(event) elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): @@ -294,11 +276,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue - self._task_state.answer += delta_text yield self._text_chunk_to_stream_response(delta_text) elif isinstance(event, QueueMessageReplaceEvent): @@ -364,170 +341,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa task_id=self._application_generate_entity.task_id, text=TextReplaceStreamResponse.Data(text=text) ) - - def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: - """ - Get stream generate nodes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - end_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.END.value - ] - - # parse stream output node value selectors of end nodes - stream_generate_routes = {} - for node_config in end_node_configs: - # get generate route for stream output - end_node_id = node_config['id'] - generate_nodes = EndNode.extract_generate_nodes(graph, node_config) - start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( - end_node_id=end_node_id, - stream_node_ids=generate_nodes - ) - - return stream_generate_routes - - def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get end start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. - :return: - """ - if self._task_state.current_stream_generate_state: - stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids - - for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): - if node_id not in stream_node_ids: - continue - - node_execution_info = self._task_state.ran_node_execution_infos[node_id] - - # get chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() - - if not route_chunk_node_execution: - continue - - outputs = route_chunk_node_execution.outputs_dict - - if not outputs: - continue - - # get value from outputs - text = outputs.get('text') - - if text: - self._task_state.answer += text - yield self._text_chunk_to_stream_response(text) - - db.session.close() - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return False - - if 'node_id' not in event.metadata: - return False - - node_id = event.metadata.get('node_id') - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - return True - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - \ No newline at end of file diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index cd858419d1..7caa2a7049 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -133,7 +133,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') self.print_text(f"Node ID: {route_node_state.node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') + self.print_text(f"Type: {node_type}", color='green') if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result @@ -145,7 +145,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color='green') self.print_text( - f"Metadata: {jsonable_encoder(node_run_result.execution_metadata) if node_run_result.execution_metadata else ''}", + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", color='green') def on_workflow_node_execute_failed( @@ -166,7 +166,7 @@ class WorkflowLoggingCallback(WorkflowCallback): self.print_text("\n[on_workflow_node_execute_failed]", color='red') self.print_text(f"Node ID: {route_node_state.node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') + self.print_text(f"Type: {node_type}", color='red') if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 6ecdd91690..c4882ff669 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from datetime import datetime from enum import Enum from typing import Any, Optional @@ -7,6 +7,7 @@ from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class QueueEvent(str, Enum): @@ -60,6 +61,7 @@ class QueueIterationStartEvent(AppQueueEvent): QueueIterationStartEvent entity """ event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData @@ -67,11 +69,12 @@ class QueueIterationStartEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None + inputs: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[Mapping[str, Any]] = None + metadata: Optional[dict[str, Any]] = None class QueueIterationNextEvent(AppQueueEvent): """ @@ -80,8 +83,10 @@ class QueueIterationNextEvent(AppQueueEvent): event: QueueEvent = QueueEvent.ITERATION_NEXT index: int + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData parallel_id: Optional[str] = None """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None @@ -108,17 +113,20 @@ class QueueIterationCompletedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ITERATION_COMPLETED + node_execution_id: str node_id: str node_type: NodeType + node_data: BaseNodeData parallel_id: Optional[str] = None """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None steps: int = 0 error: Optional[str] = None @@ -130,7 +138,6 @@ class QueueTextChunkEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - metadata: Optional[dict] = None class QueueAgentMessageEvent(AppQueueEvent): @@ -185,6 +192,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent): QueueWorkflowStartedEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): @@ -192,6 +200,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent): QueueWorkflowSucceededEvent entity """ event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None class QueueWorkflowFailedEvent(AppQueueEvent): @@ -208,6 +217,7 @@ class QueueNodeStartedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_STARTED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData @@ -217,6 +227,7 @@ class QueueNodeStartedEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + start_at: datetime class QueueNodeSucceededEvent(AppQueueEvent): @@ -225,6 +236,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_SUCCEEDED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData @@ -232,11 +244,12 @@ class QueueNodeSucceededEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None @@ -247,6 +260,7 @@ class QueueNodeFailedEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.NODE_FAILED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData @@ -254,10 +268,11 @@ class QueueNodeFailedEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + start_at: datetime - inputs: Optional[Mapping[str, Any]] = None - process_data: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index ddf8200c77..bc3ec7980c 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -3,30 +3,11 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType from models.workflow import WorkflowNodeExecutionStatus -class WorkflowStreamGenerateNodes(BaseModel): - """ - WorkflowStreamGenerateNodes entity - """ - end_node_id: str - stream_node_ids: list[str] - - -class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class TaskState(BaseModel): """ TaskState entity @@ -47,27 +28,6 @@ class WorkflowTaskState(TaskState): """ answer: str = "" - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None - - iteration_nested_node_ids: list[str] = None - - -class AdvancedChatTaskState(WorkflowTaskState): - """ - AdvancedChatTaskState entity - """ - usage: LLMUsage - - current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None - class StreamEvent(Enum): """ @@ -398,8 +358,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): title: str outputs: Optional[dict] = None created_at: int - extras: dict = None - inputs: dict = None + extras: Optional[dict] = None + inputs: Optional[dict] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float @@ -552,25 +512,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data - - -class WorkflowIterationState(BaseModel): - """ - WorkflowIterationState entity - """ - - class Data(BaseModel): - """ - Data entity - """ - parent_iteration_id: Optional[str] = None - iteration_id: str - current_index: int - iteration_steps_boundary: list[int] = None - node_execution_id: str - started_at: float - inputs: dict = None - total_tokens: int = 0 - node_data: BaseNodeData - - current_iterations: dict[str, Data] = None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a3c1fb5824..39ca88869a 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline: self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception: + def _handle_error(self, event: QueueErrorEvent, message: Message) -> Exception: """ Handle error event. :param event: event @@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline: err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) if message: - message = db.session.query(Message).filter(Message.id == message.id).first() - err_desc = self._error_to_desc(err) - message.status = 'error' - message.error = err_desc + refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.commit() + if refetch_message: + err_desc = self._error_to_desc(err) + refetch_message.status = 'error' + refetch_message.error = err_desc + + db.session.commit() return err - def _error_to_desc(cls, e: Exception) -> str: + def _error_to_desc(self, e: Exception) -> str: """ Error to desc. :param e: exception diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 76c50809cf..8ff50dd174 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, @@ -16,11 +15,11 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager @@ -36,7 +35,7 @@ class MessageCycleManage: AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] - _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + _task_state: Union[EasyUITaskState, WorkflowTaskState] def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: """ @@ -45,6 +44,9 @@ class MessageCycleManage: :param query: query :return: thread """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) @@ -52,7 +54,7 @@ class MessageCycleManage: if auto_generate_conversation_name and is_first_message: # start generate thread thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), + 'flask_app': current_app._get_current_object(), # type: ignore 'conversation_id': conversation.id, 'query': query }) @@ -75,6 +77,9 @@ class MessageCycleManage: .first() ) + if not conversation: + return + if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app if not app_model: @@ -121,34 +126,13 @@ class MessageCycleManage: if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata['retriever_resources'] = event.retriever_resources - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. :param event: event :return: """ - message_file: MessageFile = ( + message_file = ( db.session.query(MessageFile) .filter(MessageFile.id == event.message_file_id) .first() diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 513fc692ff..6c2d3bab54 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,32 +1,34 @@ import json import time from datetime import datetime, timezone -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - NodeExecutionInfo, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, + WorkflowTaskState, ) -from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage from core.file.file_obj import FileVar from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeType, SystemVariable from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -42,51 +44,54 @@ from models.workflow import ( ) -class WorkflowCycleManage(WorkflowIterationCycleManage): - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _workflow_system_variables: dict[SystemVariable, Any] + + def _handle_workflow_run_start(self) -> WorkflowRun: + max_sequence = ( + db.session.query(db.func.max(WorkflowRun.sequence_number)) + .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) + .filter(WorkflowRun.app_id == self._workflow.app_id) + .scalar() + or 0 + ) new_sequence_number = max_sequence + 1 - inputs = {**user_inputs} - for key, value in (system_inputs or {}).items(): + inputs = {**self._application_generate_entity.inputs} + for key, value in (self._workflow_system_variables or {}).items(): if key.value == 'conversation': continue inputs[f'sys.{key.value}'] = value - inputs = WorkflowEngineManager.handle_special_values(inputs) + + inputs = WorkflowEntry.handle_special_values(inputs) + + triggered_from= ( + WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN + ) # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps(inputs), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id + workflow_run = WorkflowRun() + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING.value + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value ) + workflow_run.created_by = self._user.id db.session.add(workflow_run) db.session.commit() @@ -94,15 +99,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): db.session.close() return workflow_run - - def _workflow_run_success( - self, workflow_run: WorkflowRun, + + def _handle_workflow_run_success( + self, + workflow_run: WorkflowRun, start_at: float, total_tokens: int, total_steps: int, outputs: Optional[str] = None, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run success @@ -114,6 +120,8 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): :param conversation_id: conversation id :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value workflow_run.outputs = outputs workflow_run.elapsed_time = time.perf_counter() - start_at @@ -123,7 +131,6 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): db.session.commit() db.session.refresh(workflow_run) - db.session.close() if trace_manager: trace_manager.add_trace_task( @@ -134,17 +141,20 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): ) ) + db.session.close() + return workflow_run - def _workflow_run_failed( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_failed( + self, + workflow_run: WorkflowRun, start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, error: str, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run failed @@ -156,6 +166,8 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): :param error: error message :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = status.value workflow_run.error = error workflow_run.elapsed_time = time.perf_counter() - start_at @@ -177,40 +189,25 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): ) return workflow_run - - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ + + def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.add(workflow_node_execution) db.session.commit() @@ -219,32 +216,25 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: """ Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata + :param event: queue node succeeded event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp() workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -253,42 +243,38 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return workflow_node_execution - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None - ) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: """ Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message + :param event: queue node failed event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.error = event.error + workflow_node_execution.elapsed_time = time.perf_counter() - event.start_at.timestamp() workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None + workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None db.session.commit() db.session.refresh(workflow_node_execution) db.session.close() return workflow_node_execution + + ################################################# + # to stream responses # + ################################################# - def _workflow_start_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + def _workflow_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -302,13 +288,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, - created_at=int(workflow_run.created_at.timestamp()) - ) + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + ), ) - def _workflow_finish_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. :param task_id: task id @@ -320,16 +307,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): created_by_account = workflow_run.created_by_account if created_by_account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + 'id': created_by_account.id, + 'name': created_by_account.name, + 'email': created_by_account.email, } else: created_by_end_user = workflow_run.created_by_end_user if created_by_end_user: created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, + 'id': created_by_end_user.id, + 'user': created_by_end_user.session_id, } return WorkflowFinishStreamResponse( @@ -348,14 +335,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + ), ) - def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution) \ - -> NodeStartStreamResponse: + def _workflow_node_start_to_stream_response( + self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> NodeStartStreamResponse: """ Workflow node start to stream response. :param event: queue node started event @@ -374,8 +360,8 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()) - ) + created_at=int(workflow_node_execution.created_at.timestamp()), + ), ) # extras logic @@ -384,13 +370,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): response.data.extras['icon'] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, - provider_id=node_data.provider_id + provider_id=node_data.provider_id, ) return response - def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + def _workflow_node_finish_to_stream_response( + self, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> NodeFinishStreamResponse: """ Workflow node finish to stream response. :param task_id: task id @@ -416,184 +403,90 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + ), + ) + + def _workflow_iteration_start_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + """ + Workflow iteration start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration start event + :return: + """ + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {} ) ) - - def _handle_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs=self._workflow_system_variables + + def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse: + """ + Workflow iteration next to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration next event + :return: + """ + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={} + ) ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - - execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None - - if self._iteration_state and self._iteration_state.current_iterations: - if not execution_metadata: - execution_metadata = {} - current_iteration_data = None - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if data.parent_iteration_id == None: - current_iteration_data = data - break - - if current_iteration_data: - execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id - execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index - - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, + + def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse: + """ + Workflow iteration completed to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration completed event + :return: + """ + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, outputs=event.outputs, - execution_metadata=execution_metadata + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=time.perf_counter() - event.start_at.timestamp(), + total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps ) - - if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if self._iteration_state: - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=execution_metadata - ) - - db.session.close() - - return workflow_node_execution - - def _handle_workflow_finished( - self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None - ) -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - - if conversation_id is None: - conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - latest_node_execution_info = self._task_state.latest_node_execution_info - if latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() - if (workflow_node_execution - and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=latest_node_execution_info.start_at, - error='Workflow stopped.' - ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - start_at=self._task_state.start_at, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run + ) def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: """ @@ -650,3 +543,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage): return value.to_dict() return None + + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Refetch workflow run + :param workflow_run_id: workflow run id + :return: + """ + workflow_run = db.session.query(WorkflowRun).filter( + WorkflowRun.id == workflow_run_id).first() + + if not workflow_run: + raise Exception(f'Workflow run not found: {workflow_run_id}') + + return workflow_run + + def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + """ + Refetch workflow node execution + :param node_execution_id: workflow node execution id + :return: + """ + workflow_node_execution = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id, + WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id, + WorkflowNodeExecution.workflow_id == self._workflow.id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.node_execution_id == node_execution_id, + ) + .first() + ) + + if not workflow_node_execution: + raise Exception(f'Workflow node execution not found: {node_execution_id}') + + return workflow_node_execution \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py deleted file mode 100644 index 545f31fddf..0000000000 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Any, Union - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.entities.node_entities import SystemVariable -from models.account import Account -from models.model import EndUser -from models.workflow import Workflow - - -class WorkflowCycleStateManager: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] \ No newline at end of file diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py deleted file mode 100644 index aff1870714..0000000000 --- a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import time -from collections.abc import Generator -from datetime import datetime, timezone -from typing import Optional, Union - -from core.app.entities.queue_entities import ( - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, -) -from core.app.entities.task_entities import ( - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - NodeExecutionInfo, - WorkflowIterationState, -) -from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.workflow_engine_manager import WorkflowEngineManager -from extensions.ext_database import db -from models.workflow import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, -) - - -class WorkflowIterationCycleManage(WorkflowCycleStateManager): - _iteration_state: WorkflowIterationState = None - - def _init_iteration_state(self) -> WorkflowIterationState: - if not self._iteration_state: - self._iteration_state = WorkflowIterationState( - current_iterations={} - ) - - def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ - -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: - """ - Handle iteration to stream response - :param task_id: task id - :param event: iteration event - :return: - """ - if isinstance(event, QueueIterationStartEvent): - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs, - metadata=event.metadata - ) - ) - elif isinstance(event, QueueIterationNextEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={} - ) - ) - elif isinstance(event, QueueIterationCompletedEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - error=None, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) - - def _init_iteration_execution_from_workflow_run(self, - workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - inputs=json.dumps(inputs) if inputs else None, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - execution_metadata=json.dumps({ - 'started_run_index': node_run_index + 1, - 'current_index': 0, - 'steps_boundary': [], - }), - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: - if isinstance(event, QueueIterationStartEvent): - return self._handle_iteration_started(event) - elif isinstance(event, QueueIterationNextEvent): - return self._handle_iteration_next(event) - elif isinstance(event, QueueIterationCompletedEvent): - return self._handle_iteration_completed(event) - - def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: - self._init_iteration_state() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_iteration_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=NodeType.ITERATION, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( - parent_iteration_id=None, - iteration_id=event.node_id, - current_index=0, - iteration_steps_boundary=[], - node_execution_id=workflow_node_execution.id, - started_at=time.perf_counter(), - inputs=event.inputs, - total_tokens=0, - node_data=event.node_data - ) - - db.session.close() - - return workflow_node_execution - - def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: - if event.node_id not in self._iteration_state.current_iterations: - return - current_iteration = self._iteration_state.current_iterations[event.node_id] - current_iteration.current_index = event.index - current_iteration.iteration_steps_boundary.append(event.node_run_index) - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['current_index'] = event.index - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - db.session.close() - - def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): - if event.node_id not in self._iteration_state.current_iterations: - return - - current_iteration = self._iteration_state.current_iterations[event.node_id] - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - # remove current iteration - self._iteration_state.current_iterations.pop(event.node_id, None) - - # set latest node execution info - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.latest_node_execution_info = latest_node_execution_info - - db.session.close() - - def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: - """ - Handle iteration exception - """ - if not self._iteration_state or not self._iteration_state.current_iterations: - return - - for node_id, current_iteration in self._iteration_state.current_iterations.items(): - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - db.session.commit() - db.session.close() - - yield IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=node_id, - node_id=node_id, - node_type=NodeType.ITERATION.value, - title=current_iteration.node_data.title, - outputs={}, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a..59a4c103a2 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -63,6 +63,39 @@ class LLMUsage(ModelUsage): latency=0.0 ) + def plus(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Add two LLMUsage instances together. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + if self.total_tokens == 0: + return other + else: + return LLMUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + prompt_unit_price=other.prompt_unit_price, + prompt_price_unit=other.prompt_price_unit, + prompt_price=self.prompt_price + other.prompt_price, + completion_tokens=self.completion_tokens + other.completion_tokens, + completion_unit_price=other.completion_unit_price, + completion_price_unit=other.completion_price_unit, + completion_price=self.completion_price + other.completion_price, + total_tokens=self.total_tokens + other.total_tokens, + total_price=self.total_price + other.total_price, + currency=other.currency, + latency=self.latency + other.latency + ) + + def __add__(self, other: 'LLMUsage') -> 'LLMUsage': + """ + Overload the + operator to add two LLMUsage instances. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + return self.plus(other) class LLMResult(BaseModel): """ diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 9a4d8db4e2..69e28770c3 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -34,13 +34,13 @@ class OutputModeration(BaseModel): final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def should_direct_output(self): + def should_direct_output(self) -> bool: return self.final_output is not None - def get_final_output(self): - return self.final_output + def get_final_output(self) -> str: + return self.final_output or "" - def append_new_token(self, token: str): + def append_new_token(self, token: str) -> None: self.buffer += token if not self.thread: diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 5f165d2e42..6cc639f55f 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,9 +1,9 @@ -from collections.abc import Mapping from enum import Enum from typing import Any, Optional from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMUsage from models.workflow import WorkflowNodeExecutionStatus @@ -83,10 +83,11 @@ class NodeRunResult(BaseModel): """ status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING - inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[Mapping[str, Any]] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata + inputs: Optional[dict[str, Any]] = None # node inputs + process_data: Optional[dict[str, Any]] = None # process data + outputs: Optional[dict[str, Any]] = None # node outputs + metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 04303cb293..eae6ffec02 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from datetime import datetime from typing import Any, Optional from pydantic import BaseModel, Field @@ -26,7 +26,8 @@ class GraphRunStartedEvent(BaseGraphEvent): class GraphRunSucceededEvent(BaseGraphEvent): - pass + outputs: Optional[dict[str, Any]] = None + """outputs""" class GraphRunFailedEvent(BaseGraphEvent): @@ -39,6 +40,7 @@ class GraphRunFailedEvent(BaseGraphEvent): class BaseNodeEvent(GraphEngineEvent): + id: str = Field(..., description="node execution id") node_id: str = Field(..., description="node id") node_type: NodeType = Field(..., description="node type") node_data: BaseNodeData = Field(..., description="node data") @@ -47,7 +49,8 @@ class BaseNodeEvent(GraphEngineEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" - in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration") + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class NodeRunStartedEvent(BaseNodeEvent): @@ -82,7 +85,8 @@ class NodeRunFailedEvent(BaseNodeEvent): class BaseParallelBranchEvent(GraphEngineEvent): parallel_id: str = Field(..., description="parallel id") parallel_start_node_id: str = Field(..., description="parallel start node id") - in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration") + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): @@ -103,6 +107,7 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): class BaseIterationEvent(GraphEngineEvent): + iteration_id: str = Field(..., description="iteration node execution id") iteration_node_id: str = Field(..., description="iteration node id") iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") iteration_node_data: BaseNodeData = Field(..., description="node data") @@ -113,8 +118,9 @@ class BaseIterationEvent(GraphEngineEvent): class IterationRunStartedEvent(BaseIterationEvent): - inputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None @@ -124,16 +130,18 @@ class IterationRunNextEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent): - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None steps: int = 0 class IterationRunFailedEvent(BaseIterationEvent): - inputs: Optional[Mapping[str, Any]] = None - outputs: Optional[Mapping[str, Any]] = None - metadata: Optional[Mapping[str, Any]] = None + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None steps: int = 0 error: str = Field(..., description="failed reason") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index c1e058065f..014e406011 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -24,6 +24,8 @@ class GraphParallel(BaseModel): start_from_node_id: str = Field(..., description="start from node id") parent_parallel_id: Optional[str] = None """parent parallel id""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id""" end_to_node_id: Optional[str] = None """end to node id""" @@ -101,7 +103,7 @@ class Graph(BaseModel): # parse run condition run_condition = None - if edge_config.get('sourceHandle'): + if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source': run_condition = RunCondition( type='branch_identify', branch_identify=edge_config.get('sourceHandle') @@ -176,7 +178,8 @@ class Graph(BaseModel): # init end stream param end_stream_param = EndStreamGeneratorRouter.init( node_id_config_mapping=node_id_config_mapping, - reverse_edge_mapping=reverse_edge_mapping + reverse_edge_mapping=reverse_edge_mapping, + node_parallel_mapping=node_parallel_mapping ) # init graph @@ -287,9 +290,17 @@ class Graph(BaseModel): if all(node_id in node_parallel_mapping for node_id in parallel_node_ids): parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]] + if not parent_parallel_id: + raise Exception(f"Parent parallel id not found for node ids {parallel_node_ids}") + + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + raise Exception(f"Parent parallel {parent_parallel_id} not found") + parallel = GraphParallel( start_from_node_id=start_node_id, - parent_parallel_id=parent_parallel_id + parent_parallel_id=parent_parallel.id, + parent_parallel_start_node_id=parent_parallel.start_from_node_id ) parallel_mapping[parallel.id] = parallel diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index e6ae6df559..c7d484ddf5 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -1,15 +1,24 @@ +from typing import Any + from pydantic import BaseModel, Field +from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState class GraphRuntimeState(BaseModel): variable_pool: VariablePool = Field(..., description="variable pool") - + """variable pool""" + start_at: float = Field(..., description="start time") + """start time""" total_tokens: int = 0 """total tokens""" + llm_usage: LLMUsage = LLMUsage.empty_usage() + """llm usage info""" + outputs: dict[str, Any] = {} + """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 3cc0354158..5a0d8a10e2 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -10,7 +10,11 @@ from uritemplate.variable import VariableValue from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeType, + UserFrom, +) from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -108,13 +112,29 @@ class GraphEngine: if isinstance(item, NodeRunFailedEvent): yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.') return + elif isinstance(item, NodeRunSucceededEvent): + if item.node_type == NodeType.END: + self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {}) + elif item.node_type == NodeType.ANSWER: + if "answer" not in self.graph_runtime_state.outputs: + self.graph_runtime_state.outputs["answer"] = "" + + self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "") + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip() except Exception as e: logger.exception(f"Graph run failed: {str(e)}") yield GraphRunFailedEvent(error=str(e)) return # trigger graph run success event - yield GraphRunSucceededEvent() + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) except GraphRunFailedError as e: yield GraphRunFailedEvent(error=e.error) return @@ -163,6 +183,7 @@ class GraphEngine: # init workflow run state node_instance = node_cls( # type: ignore + id=route_node_state.id, config=node_config, graph_init_params=self.init_params, graph=self.graph, @@ -192,6 +213,7 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( error=str(e), + id=node_instance.id, node_id=next_node_id, node_type=node_type, node_data=node_instance.node_data, @@ -291,7 +313,7 @@ class GraphEngine: continue elif isinstance(event, ParallelBranchRunFailedEvent): - raise GraphRunFailedError(event.reason) + raise GraphRunFailedError(event.error) except queue.Empty: continue @@ -360,6 +382,7 @@ class GraphEngine: """ # trigger node run start event yield NodeRunStartedEvent( + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, @@ -383,7 +406,8 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( - error=route_node_state.failed_reason, + error=route_node_state.failed_reason or 'Unknown error.', + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, @@ -398,6 +422,10 @@ class GraphEngine: run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] ) + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + # append node output variables to variable pool if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): @@ -409,6 +437,7 @@ class GraphEngine: ) yield NodeRunSucceededEvent( + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, @@ -420,6 +449,7 @@ class GraphEngine: break elif isinstance(item, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, @@ -431,6 +461,7 @@ class GraphEngine: ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, @@ -450,6 +481,7 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( error="Workflow stopped.", + id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index b9d0d05b98..7c5d2858e8 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -16,11 +16,13 @@ class BaseNode(ABC): _node_type: NodeType def __init__(self, + id: str, config: Mapping[str, Any], graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState, previous_node_id: Optional[str] = None) -> None: + self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id self.workflow_type = graph_init_params.workflow_type diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index d2c9578019..77d1d5efb0 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -7,7 +7,8 @@ class EndStreamGeneratorRouter: @classmethod def init(cls, node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str] ) -> EndStreamParam: """ Get stream generate routes. @@ -19,6 +20,10 @@ class EndStreamGeneratorRouter: if not node_config.get('data', {}).get('type') == NodeType.END.value: continue + # skip end node in parallel + if end_node_id in node_parallel_mapping: + continue + # get generate route for stream output stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index ce7c0010a5..c5190440bd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,5 +1,6 @@ import logging from collections.abc import Generator +from datetime import datetime, timezone from typing import Any, cast from configs import dify_config @@ -123,10 +124,14 @@ class IterationNode(BaseNode): max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) + start_at = datetime.now(timezone.utc).replace(tzinfo=None) + yield IterationRunStartedEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, + start_at=start_at, inputs=inputs, metadata={ "iterator_length": 1 @@ -135,6 +140,7 @@ class IterationNode(BaseNode): ) yield IterationRunNextEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, @@ -186,6 +192,7 @@ class IterationNode(BaseNode): ) yield IterationRunNextEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, @@ -197,9 +204,11 @@ class IterationNode(BaseNode): if isinstance(event, GraphRunFailedEvent): # iteration run failed yield IterationRunFailedEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, + start_at=start_at, inputs=inputs, outputs={ "output": jsonable_encoder(outputs) @@ -222,9 +231,11 @@ class IterationNode(BaseNode): yield event yield IterationRunSucceededEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, + start_at=start_at, inputs=inputs, outputs={ "output": jsonable_encoder(outputs) @@ -247,9 +258,11 @@ class IterationNode(BaseNode): # iteration run failed logger.exception("Iteration run failed") yield IterationRunFailedEvent( + iteration_id=self.id, iteration_node_id=self.node_id, iteration_node_type=self.node_type, iteration_node_data=self.node_data, + start_at=start_at, inputs=inputs, outputs={ "output": jsonable_encoder(outputs) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 196aa169a0..cdd7641b81 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,3 +1,4 @@ +import logging from typing import Any, cast from sqlalchemy import func @@ -20,6 +21,8 @@ from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +logger = logging.getLogger(__name__) + default_retrieval_model = { 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, 'reranking_enable': False, @@ -67,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode): ) except Exception as e: - + logger.exception("Error when running knowledge retrieval node") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index a381e55dae..4e5ecb42b4 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -168,7 +168,8 @@ class LLMNode(BaseNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) ) diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 8f3ecf3793..eb28052f72 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -175,7 +175,8 @@ class ParameterExtractorNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) def _invoke_llm(self, node_data_model: ModelConfig, diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index a21e111b95..dc757b7608 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -119,7 +119,8 @@ class QuestionClassifierNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) except ValueError as e: @@ -131,7 +132,8 @@ class QuestionClassifierNode(LLMNode): NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, NodeRunMetadataKey.CURRENCY: usage.currency - } + }, + llm_usage=usage ) @classmethod diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ac88e7dea5..3c714069d3 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -117,10 +117,11 @@ class WorkflowEntry: graph_runtime_state=graph_engine.graph_runtime_state, event=event ) - yield event + yield event except GenerateTaskStoppedException: pass except Exception as e: + logger.exception("Unknown Error when workflow entry running") if callbacks: for callback in callbacks: callback.on_event( @@ -205,7 +206,7 @@ class WorkflowEntry: node_instance=node_instance ) - # run node TODO + # run node node_run_result = node_instance.run( variable_pool=variable_pool ) @@ -223,7 +224,7 @@ class WorkflowEntry: return node_instance, node_run_result @classmethod - def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: + def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: """ Handle special values :param value: value @@ -232,7 +233,7 @@ class WorkflowEntry: if not value: return None - new_value = value.copy() + new_value = dict(value) if value else {} if isinstance(new_value, dict): for key, val in new_value.items(): if isinstance(val, FileVar): diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 0000000000..76ce6d0e09 --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: eeb2e349e6ac +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = 'eeb2e349e6ac' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index df2269cd0f..805c637994 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -548,6 +548,8 @@ class WorkflowNodeExecution(db.Model): 'triggered_from', 'workflow_run_id'), db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_id'), + db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', + 'triggered_from', 'node_execution_id'), ) id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) @@ -558,6 +560,7 @@ class WorkflowNodeExecution(db.Model): workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) + node_execution_id = db.Column(db.String(255), nullable=True) node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 5bbbd2041d..fe89e5b6db 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,7 +11,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes.node_mapping import node_classes -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account @@ -209,11 +209,11 @@ class WorkflowService: raise ValueError('Workflow not initialized') # run draft workflow node - workflow_engine_manager = WorkflowEngineManager() + workflow_entry = WorkflowEntry() start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run( + node_instance, node_run_result = workflow_entry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs,