mirror of
https://github.com/langgenius/dify.git
synced 2026-04-24 00:59:19 +08:00
use callback to filter workflow stream output
This commit is contained in:
parent
6372183471
commit
79f0e894e9
@ -3,6 +3,7 @@ import time
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
|
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.apps.base_app_runner import AppRunner
|
from core.app.apps.base_app_runner import AppRunner
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
|
|||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationException
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import SystemVariable
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner):
|
|||||||
SystemVariable.FILES: files,
|
SystemVariable.FILES: files,
|
||||||
SystemVariable.CONVERSATION: conversation.id,
|
SystemVariable.CONVERSATION: conversation.id,
|
||||||
},
|
},
|
||||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
callbacks=[WorkflowEventTriggerCallback(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
workflow=workflow
|
||||||
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||||
|
|||||||
@ -7,13 +7,15 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||||
|
|
||||||
def __init__(self, queue_manager: AppQueueManager):
|
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
|
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
|
||||||
|
|
||||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||||
"""
|
"""
|
||||||
@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||||||
PublishFrom.APPLICATION_MANAGER
|
PublishFrom.APPLICATION_MANAGER
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||||
def on_text_chunk(self, text: str) -> None:
|
|
||||||
"""
|
"""
|
||||||
Publish text chunk
|
Publish text chunk
|
||||||
"""
|
"""
|
||||||
self._queue_manager.publish(
|
if node_id in self._streamable_node_ids:
|
||||||
QueueTextChunkEvent(
|
self._queue_manager.publish(
|
||||||
text=text
|
QueueTextChunkEvent(
|
||||||
), PublishFrom.APPLICATION_MANAGER
|
text=text
|
||||||
)
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
||||||
|
"""
|
||||||
|
Fetch streamable node ids
|
||||||
|
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||||
|
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||||
|
|
||||||
|
:param graph: workflow graph
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
streamable_node_ids = []
|
||||||
|
end_node_ids = []
|
||||||
|
for node_config in graph.get('nodes'):
|
||||||
|
if node_config.get('type') == NodeType.END.value:
|
||||||
|
end_node_ids.append(node_config.get('id'))
|
||||||
|
|
||||||
|
for edge_config in graph.get('edges'):
|
||||||
|
if edge_config.get('target') in end_node_ids:
|
||||||
|
streamable_node_ids.append(edge_config.get('source'))
|
||||||
|
|
||||||
|
return streamable_node_ids
|
||||||
@ -4,13 +4,13 @@ from typing import cast
|
|||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||||
|
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AppGenerateEntity,
|
AppGenerateEntity,
|
||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
||||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
|
||||||
from core.moderation.base import ModerationException
|
from core.moderation.base import ModerationException
|
||||||
from core.moderation.input_moderation import InputModeration
|
from core.moderation.input_moderation import InputModeration
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import SystemVariable
|
||||||
@ -76,7 +76,10 @@ class WorkflowAppRunner:
|
|||||||
system_inputs={
|
system_inputs={
|
||||||
SystemVariable.FILES: files
|
SystemVariable.FILES: files
|
||||||
},
|
},
|
||||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
callbacks=[WorkflowEventTriggerCallback(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
workflow=workflow
|
||||||
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||||
|
|||||||
@ -0,0 +1,87 @@
|
|||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueNodeFinishedEvent,
|
||||||
|
QueueNodeStartedEvent,
|
||||||
|
QueueTextChunkEvent,
|
||||||
|
QueueWorkflowFinishedEvent,
|
||||||
|
QueueWorkflowStartedEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||||
|
|
||||||
|
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||||
|
self._queue_manager = queue_manager
|
||||||
|
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
|
||||||
|
|
||||||
|
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||||
|
"""
|
||||||
|
Workflow run started
|
||||||
|
"""
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
|
||||||
|
"""
|
||||||
|
Workflow run finished
|
||||||
|
"""
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Workflow node execute started
|
||||||
|
"""
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Workflow node execute finished
|
||||||
|
"""
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||||
|
"""
|
||||||
|
Publish text chunk
|
||||||
|
"""
|
||||||
|
if node_id in self._streamable_node_ids:
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueTextChunkEvent(
|
||||||
|
text=text
|
||||||
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
||||||
|
"""
|
||||||
|
Fetch streamable node ids
|
||||||
|
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||||
|
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||||
|
|
||||||
|
:param graph: workflow graph
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
streamable_node_ids = []
|
||||||
|
end_node_ids = []
|
||||||
|
for node_config in graph.get('nodes'):
|
||||||
|
if node_config.get('type') == NodeType.END.value:
|
||||||
|
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||||
|
end_node_ids.append(node_config.get('id'))
|
||||||
|
|
||||||
|
for edge_config in graph.get('edges'):
|
||||||
|
if edge_config.get('target') in end_node_ids:
|
||||||
|
streamable_node_ids.append(edge_config.get('source'))
|
||||||
|
|
||||||
|
return streamable_node_ids
|
||||||
@ -1,9 +1,9 @@
|
|||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||||
|
|
||||||
|
|
||||||
class BaseWorkflowCallback:
|
class BaseWorkflowCallback(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||||
"""
|
"""
|
||||||
@ -33,7 +33,7 @@ class BaseWorkflowCallback:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_text_chunk(self, text: str) -> None:
|
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||||
"""
|
"""
|
||||||
Publish text chunk
|
Publish text chunk
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -16,7 +16,6 @@ class BaseNode:
|
|||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
node_run_result: Optional[NodeRunResult] = None
|
node_run_result: Optional[NodeRunResult] = None
|
||||||
|
|
||||||
stream_output_supported: bool = False
|
|
||||||
callbacks: list[BaseWorkflowCallback]
|
callbacks: list[BaseWorkflowCallback]
|
||||||
|
|
||||||
def __init__(self, config: dict,
|
def __init__(self, config: dict,
|
||||||
@ -71,10 +70,12 @@ class BaseNode:
|
|||||||
:param text: chunk text
|
:param text: chunk text
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if self.stream_output_supported:
|
if self.callbacks:
|
||||||
if self.callbacks:
|
for callback in self.callbacks:
|
||||||
for callback in self.callbacks:
|
callback.on_node_text_chunk(
|
||||||
callback.on_text_chunk(text)
|
node_id=self.node_id,
|
||||||
|
text=text
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||||
|
|||||||
@ -32,7 +32,6 @@ from models.workflow import (
|
|||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
WorkflowRunTriggeredFrom,
|
WorkflowRunTriggeredFrom,
|
||||||
WorkflowType,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
node_classes = {
|
node_classes = {
|
||||||
@ -171,9 +170,6 @@ class WorkflowEngineManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# fetch predecessor node ids before end node (include: llm, direct answer)
|
|
||||||
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
predecessor_node = None
|
predecessor_node = None
|
||||||
while True:
|
while True:
|
||||||
@ -187,10 +183,6 @@ class WorkflowEngineManager:
|
|||||||
if not next_node:
|
if not next_node:
|
||||||
break
|
break
|
||||||
|
|
||||||
# check if node is streamable
|
|
||||||
if next_node.node_id in streamable_node_ids:
|
|
||||||
next_node.stream_output_supported = True
|
|
||||||
|
|
||||||
# max steps 30 reached
|
# max steps 30 reached
|
||||||
if len(workflow_run_state.workflow_node_executions) > 30:
|
if len(workflow_run_state.workflow_node_executions) > 30:
|
||||||
raise ValueError('Max steps 30 reached.')
|
raise ValueError('Max steps 30 reached.')
|
||||||
@ -233,34 +225,6 @@ class WorkflowEngineManager:
|
|||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
|
|
||||||
"""
|
|
||||||
Fetch streamable node ids
|
|
||||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
|
||||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
|
||||||
|
|
||||||
:param workflow: Workflow instance
|
|
||||||
:param graph: workflow graph
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
workflow_type = WorkflowType.value_of(workflow.type)
|
|
||||||
|
|
||||||
streamable_node_ids = []
|
|
||||||
end_node_ids = []
|
|
||||||
for node_config in graph.get('nodes'):
|
|
||||||
if node_config.get('type') == NodeType.END.value:
|
|
||||||
if workflow_type == WorkflowType.WORKFLOW:
|
|
||||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
|
||||||
end_node_ids.append(node_config.get('id'))
|
|
||||||
else:
|
|
||||||
end_node_ids.append(node_config.get('id'))
|
|
||||||
|
|
||||||
for edge_config in graph.get('edges'):
|
|
||||||
if edge_config.get('target') in end_node_ids:
|
|
||||||
streamable_node_ids.append(edge_config.get('source'))
|
|
||||||
|
|
||||||
return streamable_node_ids
|
|
||||||
|
|
||||||
def _init_workflow_run(self, workflow: Workflow,
|
def _init_workflow_run(self, workflow: Workflow,
|
||||||
triggered_from: WorkflowRunTriggeredFrom,
|
triggered_from: WorkflowRunTriggeredFrom,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user