mirror of https://github.com/langgenius/dify.git
use callback to filter workflow stream output
This commit is contained in:
parent
6372183471
commit
79f0e894e9
|
|
@ -3,6 +3,7 @@ import time
|
|||
from typing import cast
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
|
|
@ -10,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
|
|||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
|
|
@ -93,7 +93,10 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
callbacks=[WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
|
|
|
|||
|
|
@ -7,13 +7,15 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
|
||||
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
|
|
@ -51,13 +53,34 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
if node_id in self._streamable_node_ids:
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('type') == NodeType.END.value:
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
|
|
@ -4,13 +4,13 @@ from typing import cast
|
|||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
|
@ -76,7 +76,10 @@ class WorkflowAppRunner:
|
|||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
callbacks=[WorkflowEventTriggerCallback(
|
||||
queue_manager=queue_manager,
|
||||
workflow=workflow
|
||||
)]
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph)
|
||||
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Workflow run finished
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Workflow node execute finished
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
if node_id in self._streamable_node_ids:
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('type') == NodeType.END.value:
|
||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
||||
class BaseWorkflowCallback:
|
||||
class BaseWorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
|
|
@ -33,7 +33,7 @@ class BaseWorkflowCallback:
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
def on_node_text_chunk(self, node_id: str, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ class BaseNode:
|
|||
node_data: BaseNodeData
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
|
||||
stream_output_supported: bool = False
|
||||
callbacks: list[BaseWorkflowCallback]
|
||||
|
||||
def __init__(self, config: dict,
|
||||
|
|
@ -71,10 +70,12 @@ class BaseNode:
|
|||
:param text: chunk text
|
||||
:return:
|
||||
"""
|
||||
if self.stream_output_supported:
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_text_chunk(text)
|
||||
if self.callbacks:
|
||||
for callback in self.callbacks:
|
||||
callback.on_node_text_chunk(
|
||||
node_id=self.node_id,
|
||||
text=text
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from models.workflow import (
|
|||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
|
|
@ -171,9 +170,6 @@ class WorkflowEngineManager:
|
|||
)
|
||||
)
|
||||
|
||||
# fetch predecessor node ids before end node (include: llm, direct answer)
|
||||
streamable_node_ids = self._fetch_streamable_node_ids(workflow, graph)
|
||||
|
||||
try:
|
||||
predecessor_node = None
|
||||
while True:
|
||||
|
|
@ -187,10 +183,6 @@ class WorkflowEngineManager:
|
|||
if not next_node:
|
||||
break
|
||||
|
||||
# check if node is streamable
|
||||
if next_node.node_id in streamable_node_ids:
|
||||
next_node.stream_output_supported = True
|
||||
|
||||
# max steps 30 reached
|
||||
if len(workflow_run_state.workflow_node_executions) > 30:
|
||||
raise ValueError('Max steps 30 reached.')
|
||||
|
|
@ -233,34 +225,6 @@ class WorkflowEngineManager:
|
|||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, workflow: Workflow, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param workflow: Workflow instance
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
workflow_type = WorkflowType.value_of(workflow.type)
|
||||
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('type') == NodeType.END.value:
|
||||
if workflow_type == WorkflowType.WORKFLOW:
|
||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
else:
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
|
|
|
|||
Loading…
Reference in New Issue