answer stream output support

This commit is contained in:
takatost 2024-03-14 20:49:53 +08:00
parent f35ae2355f
commit e6b8b13f2e
10 changed files with 413 additions and 90 deletions
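At a high level, this change pre-parses each Answer node's template into a "generate route" of literal-text and variable chunks, then streams the answer as soon as the upstream nodes referenced by those variable chunks have finished. A minimal, self-contained sketch of that idea (the names TextChunk, VarChunk and stream_route are illustrative only, not code from this commit):

from dataclasses import dataclass
from typing import Union

@dataclass
class TextChunk:
    text: str                # literal template text, streamable immediately

@dataclass
class VarChunk:
    node_id: str             # node whose output the chunk reads
    key: str                 # output key on that node

def stream_route(route: list[Union[TextChunk, VarChunk]],
                 finished_outputs: dict[str, dict]) -> list[str]:
    """Emit as much of the route as is currently resolvable, in order."""
    emitted = []
    for chunk in route:
        if isinstance(chunk, TextChunk):
            emitted.append(chunk.text)
        elif chunk.node_id in finished_outputs:
            emitted.append(str(finished_outputs[chunk.node_id].get(chunk.key, '')))
        else:
            break  # wait until the source node has finished
    return emitted

# stream_route([TextChunk('Result: '), VarChunk('llm', 'text')], {})                      -> ['Result: ']
# stream_route([TextChunk('Result: '), VarChunk('llm', 'text')], {'llm': {'text': '42'}}) -> ['Result: ', '42']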

View File

@ -2,7 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
from pydantic import BaseModel, Extra
@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueMessageFileEvent,
@ -34,6 +35,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
@ -51,15 +54,26 @@ from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
class StreamGenerateRoute(BaseModel):
"""
StreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class TaskState(BaseModel):
"""
TaskState entity
"""
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class Config:
@ -77,9 +91,11 @@ class TaskState(BaseModel):
total_tokens: int = 0
total_steps: int = 0
running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[StreamGenerateRoute] = None
class Config:
"""Configuration for this pydantic object."""
@ -122,6 +138,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
if stream:
self._stream_generate_routes = self._get_stream_generate_routes()
else:
self._stream_generate_routes = None
def process(self) -> Union[dict, Generator]:
"""
Process generate task pipeline.
@ -290,6 +311,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(data)
break
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
workflow_run_response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
@ -309,7 +335,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
}
yield self._yield_response(workflow_run_response)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
# response moderation
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
@ -390,6 +416,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
yield self._yield_response(response)
elif isinstance(event, QueueTextChunkEvent):
if not self._is_stream_out_support(
event=event
):
continue
delta_text = event.text
if delta_text is None:
continue
@ -467,20 +498,28 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
# look up stream_generate_routes when this node id is the start-at node of an answer stream route
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# stream outputs from start
self._generate_stream_outputs_when_node_start()
return workflow_node_execution
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
@ -508,8 +547,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
# stream outputs when node finished
self._generate_stream_outputs_when_node_finished()
db.session.close()
@ -517,7 +556,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_run = (db.session.query(WorkflowRun)
.filter(WorkflowRun.id == self._task_state.workflow_run_id).first())
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
@ -642,7 +682,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
QuotaExceededError: {
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
@ -660,10 +700,10 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
else:
logging.error(e)
data = {
'code': 'internal_server_error',
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
}
}
return {
'event': 'error',
@ -730,3 +770,218 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
),
queue_manager=self._queue_manager
)
def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id)
if not start_node_id:
continue
stream_generate_routes[start_node_id] = StreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
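For illustration, a hypothetical straight-line graph (not data from this commit) and the mapping this method would roughly build; the key is the start-at node resolved by the helper defined next:

graph = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'llm', 'data': {'type': 'llm'}},
        {'id': 'answer', 'data': {
            'type': 'answer',
            'title': 'Answer',
            'variables': [{'variable': 'text', 'value_selector': ['llm', 'text']}],
            'answer': 'Result: {{text}}',
        }},
    ],
    'edges': [
        {'source': 'start', 'target': 'llm'},
        {'source': 'llm', 'target': 'answer'},
    ],
}
# _get_stream_generate_routes() would yield roughly:
# {
#     'start': StreamGenerateRoute(
#         answer_node_id='answer',
#         generate_route=[TextGenerateRouteChunk(text='Result: '),
#                         VarGenerateRouteChunk(value_selector=['llm', 'text'])],
#     )
# }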
def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \
-> Optional[str]:
"""
Get the node id from which streaming for the given answer node may start.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch the first ingoing edge of the target node
ingoing_edge = None
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edge = edge
break
if not ingoing_edge:
return None
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
return None
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER
]:
start_node_id = target_node_id
elif node_type == NodeType.START.value:
start_node_id = source_node_id
else:
start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
return start_node_id
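A branching counter-example (again hypothetical): when the answer node's direct predecessor is an if-else node, the walk-back stops and streaming can only start at the answer node itself, since the template cannot be emitted before the branch is decided:

graph = {
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'if_else', 'data': {'type': 'if-else'}},
        {'id': 'answer_yes', 'data': {'type': 'answer', 'title': 'Answer', 'answer': 'Yes'}},
    ],
    'edges': [
        {'source': 'start', 'target': 'if_else'},
        {'source': 'if_else', 'target': 'answer_yes'},
    ],
}
# _get_answer_start_at_node_id(graph, 'answer_yes') -> 'answer_yes'
# (for the straight-line graph in the previous example it walks back through 'llm' and returns 'start')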
def _generate_stream_outputs_when_node_start(self) -> None:
"""
Generate stream outputs when a node starts.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
for route_chunk in self._task_state.current_stream_generate_state.generate_route:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
self._task_state.current_stream_generate_state.current_route_position += 1
else:
break
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route):
self._task_state.current_stream_generate_state = None
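In other words, once the start-at node begins, every literal chunk before the first variable chunk is already known and can be flushed right away; a small sketch using the chunk entities imported above:

route = [
    TextGenerateRouteChunk(text='Result: '),
    VarGenerateRouteChunk(value_selector=['llm', 'text']),
]
position = 0
while position < len(route) and route[position].type == 'text':
    # the pipeline publishes each character of route[position].text as a QueueTextChunkEvent
    position += 1
# position == 1: streaming now pauses until the 'llm' node has finished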
def _generate_stream_outputs_when_node_finished(self) -> None:
"""
Generate stream outputs when a node finishes.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
route_chunk_node_id = value_selector[0]
# check that the chunk's source node has already finished; otherwise stop and wait
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM supports chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key)
else:
value = value.get(key)
if value:
text = None
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, object): # TODO FILE
# convert file to markdown
text = f'![]({value.get("url")})'
pass
if text:
for token in text:
self._queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.TASK_PIPELINE
)
time.sleep(0.01)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route):
self._task_state.current_stream_generate_state = None
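For clarity, the value selector resolution used above walks the remaining keys into the finished node's outputs dict; a short sketch with hypothetical data:

value_selector = ['llm', 'structured', 'text']   # first element names the node
outputs = {'structured': {'text': 'hello'}}      # outputs_dict of that finished node
value = None
for key in value_selector[1:]:
    value = outputs.get(key) if not value else value.get(key)
# value == 'hello'; it is then published character by character as QueueTextChunkEvent(s)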
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Check whether stream output is supported for this text chunk.
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM supports chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# the chunk's value selector must match the current var route chunk's value selector
if value_selector != stream_output_value_selector:
return False
return True
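For illustration, the rough shape of the metadata an LLM text chunk carries (attached in publish_text_chunk later in this commit) and the selector comparison this check performs; the node id is hypothetical:

event_metadata = {
    'node_id': 'llm_node_1',
    'node_type': NodeType.LLM,
    'value_selector': ['llm_node_1', 'text'],
}
current_route_chunk = VarGenerateRouteChunk(value_selector=['llm_node_1', 'text'])
# the chunk is forwarded to the client only when the selectors match:
assert event_metadata['value_selector'] == current_route_chunk.value_selector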

View File

@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self) -> None:
"""
@ -114,34 +113,16 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
if node_id in self._streamable_node_ids:
self._queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
"""
Fetch streamable node ids
When the workflow type is chat, only LLM or Direct Answer nodes directly before the END node can stream output
When the workflow type is workflow, only LLM nodes directly before the END node (Plain Text mode only) can stream output
:param graph: workflow graph
:return:
"""
streamable_node_ids = []
end_node_ids = []
for node_config in graph.get('nodes'):
if node_config.get('data', {}).get('type') == NodeType.END.value:
end_node_ids.append(node_config.get('id'))
for edge_config in graph.get('edges'):
if edge_config.get('target') in end_node_ids:
streamable_node_ids.append(edge_config.get('source'))
return streamable_node_ids
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)

View File

@ -3,12 +3,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
@ -54,8 +53,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
| QueueAdvancedChatMessageEndEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():

View File

@ -112,7 +112,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""

View File

@ -17,6 +17,7 @@ class QueueEvent(Enum):
AGENT_MESSAGE = "agent_message"
MESSAGE_REPLACE = "message_replace"
MESSAGE_END = "message_end"
ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
event = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
class QueueAgentMessageEvent(AppQueueEvent):
@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent):
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
llm_result: LLMResult
llm_result: Optional[LLMResult] = None
class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
"""
QueueAdvancedChatMessageEndEvent entity
"""
event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
class QueueWorkflowStartedEvent(AppQueueEvent):

View File

@ -64,7 +64,7 @@ class BaseWorkflowCallback(ABC):
raise NotImplementedError
@abstractmethod
def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""

View File

@ -4,7 +4,12 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from models.workflow import WorkflowNodeExecutionStatus
@ -22,6 +27,40 @@ class AnswerNode(BaseNode):
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)
answer = []
for part in generate_routes:
if part.type == "var":
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get_variable_value(
variable_selector=value_selector,
target_value_type=ValueType.STRING
)
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
part = cast(TextGenerateRouteChunk, part)
answer_part = {
"type": "text",
"text": part.text
}
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
answer[-1]["text"] += answer_part["text"]
else:
answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
# re-fetch variable values
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
@ -31,7 +70,39 @@ class AnswerNode(BaseNode):
variable_values[variable_selector.variable] = value
variable_keys = list(variable_values.keys())
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
outputs={
"answer": answer
}
)
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(node_data.answer)
@ -44,46 +115,24 @@ class AnswerNode(BaseNode):
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
split_template = [
{
"type": "var" if self._is_variable(part, variable_keys) else "text",
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
}
for part in template.split('Ω') if part
]
generate_routes = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
answer = []
for part in split_template:
if part["type"] == "var":
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
answer_part = {
"type": "text",
"text": part["value"]
}
return generate_routes
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
answer[-1]["text"] += answer_part["text"]
else:
answer.append(answer_part)
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
outputs={
"answer": answer
}
)
def _is_variable(self, part, variable_keys):
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
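A short trace of the Ω-delimiter split performed in extract_generate_route_from_node_data, with a hypothetical template and variable:

template = 'Result: {{text}}!'
variable_keys = ['text']
for var in variable_keys:
    template = template.replace('{{' + var + '}}', 'Ω{{' + var + '}}Ω')
# template == 'Result: Ω{{text}}Ω!'
parts = [part for part in template.split('Ω') if part]
# parts == ['Result: ', '{{text}}', '!']
# -> TextGenerateRouteChunk(text='Result: '),
#    VarGenerateRouteChunk(value_selector=<selector mapped from 'text'>),
#    TextGenerateRouteChunk(text='!')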

View File

@ -1,3 +1,6 @@
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
"""
variables: list[VariableSelector] = []
answer: str
class GenerateRouteChunk(BaseModel):
"""
Generate Route Chunk.
"""
type: str
class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: str = "var"
value_selector: list[str]
class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: str = "text"
text: str
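A hypothetical usage of these chunk entities, matching how AnswerNode and the task pipeline construct them:

route = [
    TextGenerateRouteChunk(text='Result: '),
    VarGenerateRouteChunk(value_selector=['llm', 'text']),
]
assert [chunk.type for chunk in route] == ['text', 'var']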

View File

@ -86,17 +86,22 @@ class BaseNode(ABC):
self.node_run_result = result
return result
def publish_text_chunk(self, text: str) -> None:
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text
text=text,
metadata={
"node_type": self.node_type,
"value_selector": value_selector
}
)
@classmethod

View File

@ -169,7 +169,7 @@ class LLMNode(BaseNode):
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text)
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
if not model:
model = result.model