From 79f0e894e97c2f71772d044b1caeaeb570804d9e Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 7 Mar 2024 09:55:29 +0800 Subject: [PATCH] use callback to filter workflow stream output --- api/core/app/apps/advanced_chat/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 41 +++++++-- api/core/app/apps/workflow/app_runner.py | 7 +- .../workflow_event_trigger_callback.py | 87 +++++++++++++++++++ .../callbacks/base_workflow_callback.py | 6 +- api/core/workflow/nodes/base_node.py | 11 +-- api/core/workflow/workflow_engine_manager.py | 36 -------- 7 files changed, 138 insertions(+), 57 deletions(-) rename api/core/{callback_handler => app/apps/advanced_chat}/workflow_event_trigger_callback.py (55%) create mode 100644 api/core/app/apps/workflow/workflow_event_trigger_callback.py diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 8fff8fc37e..077f0c2de0 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,7 @@ import time from typing import cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig +from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( @@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent -from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable from core.workflow.workflow_engine_manager import WorkflowEngineManager @@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.FILES: files, SystemVariable.CONVERSATION: conversation.id, }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/callback_handler/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py similarity index 55% rename from api/core/callback_handler/workflow_event_trigger_callback.py rename to api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py index f8bad94252..44fb5905b0 100644 --- a/api/core/callback_handler/workflow_event_trigger_callback.py +++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py @@ -7,13 +7,15 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, ) from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from models.workflow import WorkflowNodeExecution, WorkflowRun +from core.workflow.entities.node_entities import NodeType +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun class WorkflowEventTriggerCallback(BaseWorkflowCallback): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): self._queue_manager = queue_manager + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ @@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback): PublishFrom.APPLICATION_MANAGER ) - - def on_text_chunk(self, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text - ), PublishFrom.APPLICATION_MANAGER - ) + if node_id in self._streamable_node_ids: + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER + ) + + def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: + """ + Fetch streamable node ids + When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output + When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output + + :param graph: workflow graph + :return: + """ + streamable_node_ids = [] + end_node_ids = [] + for node_config in graph.get('nodes'): + if node_config.get('type') == NodeType.END.value: + end_node_ids.append(node_config.get('id')) + + for edge_config in graph.get('edges'): + if edge_config.get('target') in end_node_ids: + streamable_node_ids.append(edge_config.get('source')) + + return streamable_node_ids diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index e675026e41..132282ffe3 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -4,13 +4,13 @@ from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfig +from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.app.entities.app_invoke_entities import ( AppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent -from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback from core.moderation.base import ModerationException from core.moderation.input_moderation import InputModeration from core.workflow.entities.node_entities import SystemVariable @@ -76,7 +76,10 @@ class WorkflowAppRunner: system_inputs={ SystemVariable.FILES: files }, - callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)] + callbacks=[WorkflowEventTriggerCallback( + queue_manager=queue_manager, + workflow=workflow + )] ) def handle_input_moderation(self, queue_manager: AppQueueManager, diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py new file mode 100644 index 0000000000..57775f2cce --- /dev/null +++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py @@ -0,0 +1,87 @@ +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import ( + QueueNodeFinishedEvent, + QueueNodeStartedEvent, + QueueTextChunkEvent, + QueueWorkflowFinishedEvent, + QueueWorkflowStartedEvent, +) +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.node_entities import NodeType +from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun + + +class WorkflowEventTriggerCallback(BaseWorkflowCallback): + + def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): + self._queue_manager = queue_manager + self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph) + + def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run started + """ + self._queue_manager.publish( + QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None: + """ + Workflow run finished + """ + self._queue_manager.publish( + QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute started + """ + self._queue_manager.publish( + QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None: + """ + Workflow node execute finished + """ + self._queue_manager.publish( + QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id), + PublishFrom.APPLICATION_MANAGER + ) + + def on_node_text_chunk(self, node_id: str, text: str) -> None: + """ + Publish text chunk + """ + if node_id in self._streamable_node_ids: + self._queue_manager.publish( + QueueTextChunkEvent( + text=text + ), PublishFrom.APPLICATION_MANAGER + ) + + def _fetch_streamable_node_ids(self, graph: dict) -> list[str]: + """ + Fetch streamable node ids + When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output + When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output + + :param graph: workflow graph + :return: + """ + streamable_node_ids = [] + end_node_ids = [] + for node_config in graph.get('nodes'): + if node_config.get('type') == NodeType.END.value: + if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': + end_node_ids.append(node_config.get('id')) + + for edge_config in graph.get('edges'): + if edge_config.get('target') in end_node_ids: + streamable_node_ids.append(edge_config.get('source')) + + return streamable_node_ids diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 3425b2b03c..3866bf2c15 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,9 +1,9 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from models.workflow import WorkflowNodeExecution, WorkflowRun -class BaseWorkflowCallback: +class BaseWorkflowCallback(ABC): @abstractmethod def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None: """ @@ -33,7 +33,7 @@ class BaseWorkflowCallback: raise NotImplementedError @abstractmethod - def on_text_chunk(self, text: str) -> None: + def on_node_text_chunk(self, node_id: str, text: str) -> None: """ Publish text chunk """ diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index efffdfae1a..1ff05f9f4e 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -16,7 +16,6 @@ class BaseNode: node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None - stream_output_supported: bool = False callbacks: list[BaseWorkflowCallback] def __init__(self, config: dict, @@ -71,10 +70,12 @@ class BaseNode: :param text: chunk text :return: """ - if self.stream_output_supported: - if self.callbacks: - for callback in self.callbacks: - callback.on_text_chunk(text) + if self.callbacks: + for callback in self.callbacks: + callback.on_node_text_chunk( + node_id=self.node_id, + text=text + ) @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 908b684930..4d881d3d04 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -32,7 +32,6 @@ from models.workflow import ( WorkflowRun, WorkflowRunStatus, WorkflowRunTriggeredFrom, - WorkflowType, ) node_classes = { @@ -171,9 +170,6 @@ class WorkflowEngineManager: ) ) - # fetch predecessor node ids before end node (include: llm, direct answer) - streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph) - try: predecessor_node = None while True: @@ -187,10 +183,6 @@ class WorkflowEngineManager: if not next_node: break - # check if node is streamable - if next_node.node_id in streamable_node_ids: - next_node.stream_output_supported = True - # max steps 30 reached if len(workflow_run_state.workflow_node_executions) > 30: raise ValueError('Max steps 30 reached.') @@ -233,34 +225,6 @@ class WorkflowEngineManager: callbacks=callbacks ) - def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]: - """ - Fetch streamable node ids - When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output - When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output - - :param workflow: Workflow instance - :param graph: workflow graph - :return: - """ - workflow_type = WorkflowType.value_of(workflow.type) - - streamable_node_ids = [] - end_node_ids = [] - for node_config in graph.get('nodes'): - if node_config.get('type') == NodeType.END.value: - if workflow_type == WorkflowType.WORKFLOW: - if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text': - end_node_ids.append(node_config.get('id')) - else: - end_node_ids.append(node_config.get('id')) - - for edge_config in graph.get('edges'): - if edge_config.get('target') in end_node_ids: - streamable_node_ids.append(edge_config.get('source')) - - return streamable_node_ids - def _init_workflow_run(self, workflow: Workflow, triggered_from: WorkflowRunTriggeredFrom, user: Union[Account, EndUser],