mirror of https://github.com/langgenius/dify.git

commit e6b8b13f2e
parent f35ae2355f

    answer stream output support
@@ -2,7 +2,7 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from pydantic import BaseModel, Extra
 
@@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
 )
 from core.app.entities.queue_entities import (
+    QueueAdvancedChatMessageEndEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
     QueueMessageFileEvent,
@@ -34,6 +35,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
 from core.moderation.output_moderation import ModerationRule, OutputModeration
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
@@ -51,15 +54,26 @@ from services.annotation_service import AppAnnotationService
 logger = logging.getLogger(__name__)
 
 
+class StreamGenerateRoute(BaseModel):
+    """
+    StreamGenerateRoute entity
+    """
+    answer_node_id: str
+    generate_route: list[GenerateRouteChunk]
+    current_route_position: int = 0
+
+
 class TaskState(BaseModel):
     """
     TaskState entity
     """
 
     class NodeExecutionInfo(BaseModel):
         """
         NodeExecutionInfo entity
         """
         workflow_node_execution_id: str
+        node_type: NodeType
         start_at: float
 
         class Config:
@@ -77,9 +91,11 @@ class TaskState(BaseModel):
     total_tokens: int = 0
     total_steps: int = 0
 
     running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
     latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
+    current_stream_generate_state: Optional[StreamGenerateRoute] = None
+
     class Config:
         """Configuration for this pydantic object."""
@@ -122,6 +138,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         self._output_moderation_handler = self._init_output_moderation()
         self._stream = stream
 
+        if stream:
+            self._stream_generate_routes = self._get_stream_generate_routes()
+        else:
+            self._stream_generate_routes = None
+
     def process(self) -> Union[dict, Generator]:
         """
         Process generate task pipeline.
@@ -290,6 +311,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                         yield self._yield_response(data)
                         break
 
+                self._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(),
+                    PublishFrom.TASK_PIPELINE
+                )
+
                 workflow_run_response = {
                     'event': 'workflow_finished',
                     'task_id': self._application_generate_entity.task_id,
@@ -309,7 +335,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 }
 
                 yield self._yield_response(workflow_run_response)
 
-            elif isinstance(event, QueueMessageEndEvent):
+            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 # response moderation
                 if self._output_moderation_handler:
                     self._output_moderation_handler.stop_thread()
@@ -390,6 +416,11 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
 
                 yield self._yield_response(response)
             elif isinstance(event, QueueTextChunkEvent):
+                if not self._is_stream_out_support(
+                    event=event
+                ):
+                    continue
+
                 delta_text = event.text
                 if delta_text is None:
                     continue
@@ -467,20 +498,28 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
 
         latest_node_execution_info = TaskState.NodeExecutionInfo(
             workflow_node_execution_id=workflow_node_execution.id,
+            node_type=event.node_type,
             start_at=time.perf_counter()
         )
 
         self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
         self._task_state.latest_node_execution_info = latest_node_execution_info
 
         self._task_state.total_steps += 1
 
         db.session.close()
 
+        # search stream_generate_routes if node id is answer start at node
+        if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
+            self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
+
+            # stream outputs from start
+            self._generate_stream_outputs_when_node_start()
+
         return workflow_node_execution
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
-        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+        current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
         workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
             WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
         if isinstance(event, QueueNodeSucceededEvent):
@@ -508,8 +547,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 error=event.error
             )
 
         # remove running node execution info
         del self._task_state.running_node_execution_infos[event.node_id]
+        # stream outputs when node finished
+        self._generate_stream_outputs_when_node_finished()
 
         db.session.close()
@@ -517,7 +556,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
 
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
             -> WorkflowRun:
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
+        workflow_run = (db.session.query(WorkflowRun)
+                        .filter(WorkflowRun.id == self._task_state.workflow_run_id).first())
         if isinstance(event, QueueStopEvent):
             workflow_run = self._workflow_run_failed(
                 workflow_run=workflow_run,
@@ -642,7 +682,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             QuotaExceededError: {
                 'code': 'provider_quota_exceeded',
                 'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
-                          "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                 'status': 400
             },
             ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
@@ -660,10 +700,10 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
         else:
             logging.error(e)
             data = {
-                'code': 'internal_server_error',
+                'code': 'internal_server_error',
                 'message': 'Internal Server Error, please contact support.',
                 'status': 500
-            }
+            }
 
         return {
             'event': 'error',
@@ -730,3 +770,218 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             ),
             queue_manager=self._queue_manager
         )
+
+    def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # find all answer nodes
+        graph = self._workflow.graph_dict
+        answer_node_configs = [
+            node for node in graph['nodes']
+            if node.get('data', {}).get('type') == NodeType.ANSWER.value
+        ]
+
+        # parse stream output node value selectors of answer nodes
+        stream_generate_routes = {}
+        for node_config in answer_node_configs:
+            # get generate route for stream output
+            answer_node_id = node_config['id']
+            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
+            start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id)
+            if not start_node_id:
+                continue
+
+            stream_generate_routes[start_node_id] = StreamGenerateRoute(
+                answer_node_id=answer_node_id,
+                generate_route=generate_route
+            )
+
+        return stream_generate_routes
+
+    def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \
+            -> Optional[str]:
+        """
+        Get answer start at node id.
+        :param graph: graph
+        :param target_node_id: target node ID
+        :return:
+        """
+        nodes = graph.get('nodes')
+        edges = graph.get('edges')
+
+        # fetch all ingoing edges from source node
+        ingoing_edge = None
+        for edge in edges:
+            if edge.get('target') == target_node_id:
+                ingoing_edge = edge
+                break
+
+        if not ingoing_edge:
+            return None
+
+        source_node_id = ingoing_edge.get('source')
+        source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
+        if not source_node:
+            return None
+
+        node_type = source_node.get('data', {}).get('type')
+        if node_type in [
+            NodeType.ANSWER.value,
+            NodeType.IF_ELSE.value,
+            NodeType.QUESTION_CLASSIFIER
+        ]:
+            start_node_id = target_node_id
+        elif node_type == NodeType.START.value:
+            start_node_id = source_node_id
+        else:
+            start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
+
+        return start_node_id
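
Taken on its own, the upstream walk above amounts to the following minimal sketch (a standalone toy, not the pipeline code; the plain node-type strings such as 'if-else' are assumed stand-ins for the NodeType values). For a linear start -> llm -> answer path, streaming for the answer node begins as soon as the START node runs; after a branching node it begins only at the answer node itself. Note, incidentally, that the committed list mixes NodeType.QUESTION_CLASSIFIER (the enum member) with .value strings, so that entry compares an enum against a string.

    # Standalone sketch of the recursive upstream walk (node-type strings
    # are assumed; the real code compares NodeType values).
    def find_answer_start_node(graph: dict, target_node_id: str) -> str | None:
        nodes = {node['id']: node for node in graph['nodes']}
        ingoing = next((e for e in graph['edges'] if e['target'] == target_node_id), None)
        if not ingoing:
            return None

        source = nodes.get(ingoing['source'])
        if not source:
            return None

        source_type = source['data']['type']
        if source_type in ('answer', 'if-else', 'question-classifier'):
            # after a branching or chained answer node, stream from the answer node itself
            return target_node_id
        if source_type == 'start':
            return ingoing['source']
        # otherwise keep walking upstream, e.g. through an LLM node
        return find_answer_start_node(graph, ingoing['source'])

    graph = {
        'nodes': [
            {'id': 'start', 'data': {'type': 'start'}},
            {'id': 'llm', 'data': {'type': 'llm'}},
            {'id': 'answer', 'data': {'type': 'answer'}},
        ],
        'edges': [
            {'source': 'start', 'target': 'llm'},
            {'source': 'llm', 'target': 'answer'},
        ],
    }

    assert find_answer_start_node(graph, 'answer') == 'start'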
+
+    def _generate_stream_outputs_when_node_start(self) -> None:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if not self._task_state.current_stream_generate_state:
+            return
+
+        for route_chunk in self._task_state.current_stream_generate_state.generate_route:
+            if route_chunk.type == 'text':
+                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                for token in route_chunk.text:
+                    self._queue_manager.publish(
+                        QueueTextChunkEvent(
+                            text=token
+                        ), PublishFrom.TASK_PIPELINE
+                    )
+                    time.sleep(0.01)
+
+                self._task_state.current_stream_generate_state.current_route_position += 1
+            else:
+                break
+
+        # all route chunks are generated
+        if self._task_state.current_stream_generate_state.current_route_position == len(
+                self._task_state.current_stream_generate_state.generate_route):
+            self._task_state.current_stream_generate_state = None
+
+    def _generate_stream_outputs_when_node_finished(self) -> None:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if not self._task_state.current_stream_generate_state:
+            return
+
+        route_chunks = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position:]
+
+        for route_chunk in route_chunks:
+            if route_chunk.type == 'text':
+                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                for token in route_chunk.text:
+                    self._queue_manager.publish(
+                        QueueTextChunkEvent(
+                            text=token
+                        ), PublishFrom.TASK_PIPELINE
+                    )
+                    time.sleep(0.01)
+            else:
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+                route_chunk_node_id = value_selector[0]
+
+                # check chunk node id is before current node id or equal to current node id
+                if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
+                    break
+
+                latest_node_execution_info = self._task_state.latest_node_execution_info
+
+                # get route chunk node execution info
+                route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
+                if (route_chunk_node_execution_info.node_type == NodeType.LLM
+                        and latest_node_execution_info.node_type == NodeType.LLM):
+                    # only LLM support chunk stream output
+                    self._task_state.current_stream_generate_state.current_route_position += 1
+                    continue
+
+                # get route chunk node execution
+                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
+
+                outputs = route_chunk_node_execution.outputs_dict
+
+                # get value from outputs
+                value = None
+                for key in value_selector[1:]:
+                    if not value:
+                        value = outputs.get(key)
+                    else:
+                        value = value.get(key)
+
+                if value:
+                    text = None
+                    if isinstance(value, str | int | float):
+                        text = str(value)
+                    elif isinstance(value, object):  # TODO FILE
+                        # convert file to markdown
+                        text = f'})'
+                        pass
+
+                    if text:
+                        for token in text:
+                            self._queue_manager.publish(
+                                QueueTextChunkEvent(
+                                    text=token
+                                ), PublishFrom.TASK_PIPELINE
+                            )
+                            time.sleep(0.01)
+
+            self._task_state.current_stream_generate_state.current_route_position += 1
+
+        # all route chunks are generated
+        if self._task_state.current_stream_generate_state.current_route_position == len(
+                self._task_state.current_stream_generate_state.generate_route):
+            self._task_state.current_stream_generate_state = None
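
A var chunk's value selector has the shape [node_id, key, subkey, ...], and the loop above walks it through the finished node's outputs dict. A standalone sketch of that lookup (names assumed; note the committed loop tests `if not value`, so a falsy intermediate such as '' or 0 restarts the lookup at the top level, while this sketch tests `is None`):

    # Standalone sketch of resolving a value selector against node outputs.
    def resolve_value_selector(outputs: dict, value_selector: list[str]):
        value = None
        for key in value_selector[1:]:  # value_selector[0] is the node id
            if value is None:
                value = outputs.get(key)
            else:
                value = value.get(key)
        return value

    outputs = {'text': 'Hello', 'usage': {'total_tokens': 42}}
    assert resolve_value_selector(outputs, ['llm', 'text']) == 'Hello'
    assert resolve_value_selector(outputs, ['llm', 'usage', 'total_tokens']) == 42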
+
+    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.metadata:
+            return True
+
+        if 'node_id' not in event.metadata:
+            return True
+
+        node_type = event.metadata.get('node_type')
+        stream_output_value_selector = event.metadata.get('value_selector')
+        if not stream_output_value_selector:
+            return False
+
+        if not self._task_state.current_stream_generate_state:
+            return False
+
+        route_chunk = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position]
+
+        if route_chunk.type != 'var':
+            return False
+
+        if node_type != NodeType.LLM:
+            # only LLM support chunk stream output
+            return False
+
+        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+        value_selector = route_chunk.value_selector
+
+        # check chunk node id is before current node id or equal to current node id
+        if value_selector != stream_output_value_selector:
+            return False
+
+        return True
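
The effect of _is_stream_out_support is a single gating rule: a text chunk passes through only while the route position currently points at a 'var' chunk whose value selector matches the selector stamped on the chunk's metadata, and only for LLM nodes. A standalone sketch with dicts in place of the pydantic models (all names assumed):

    # Standalone sketch of the gating rule in _is_stream_out_support.
    def is_stream_out_supported(metadata: dict | None,
                                current_route_chunk: dict | None) -> bool:
        if not metadata or 'node_id' not in metadata:
            return True  # chunks without node metadata always pass

        if not metadata.get('value_selector') or current_route_chunk is None:
            return False

        if current_route_chunk['type'] != 'var' or metadata.get('node_type') != 'llm':
            return False  # only LLM chunk output is streamed

        return current_route_chunk['value_selector'] == metadata['value_selector']

    chunk_meta = {'node_id': 'llm', 'node_type': 'llm', 'value_selector': ['llm', 'text']}
    route_chunk = {'type': 'var', 'value_selector': ['llm', 'text']}
    assert is_stream_out_supported(chunk_meta, route_chunk)
    assert not is_stream_out_supported(chunk_meta, {'type': 'text'})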
@@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
-        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
 
     def on_workflow_run_started(self) -> None:
         """
@@ -114,34 +113,16 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
             PublishFrom.APPLICATION_MANAGER
         )
 
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
-        if node_id in self._streamable_node_ids:
-            self._queue_manager.publish(
-                QueueTextChunkEvent(
-                    text=text
-                ), PublishFrom.APPLICATION_MANAGER
-            )
-
-    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
-        """
-        Fetch streamable node ids
-        When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
-        When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
-
-        :param graph: workflow graph
-        :return:
-        """
-        streamable_node_ids = []
-        end_node_ids = []
-        for node_config in graph.get('nodes'):
-            if node_config.get('data', {}).get('type') == NodeType.END.value:
-                end_node_ids.append(node_config.get('id'))
-
-        for edge_config in graph.get('edges'):
-            if edge_config.get('target') in end_node_ids:
-                streamable_node_ids.append(edge_config.get('source'))
-
-        return streamable_node_ids
+        self._queue_manager.publish(
+            QueueTextChunkEvent(
+                text=text,
+                metadata={
+                    "node_id": node_id,
+                    **metadata
+                }
+            ), PublishFrom.APPLICATION_MANAGER
+        )
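
One sharp edge in the rewritten callback: `**metadata` assumes a dict, so a caller that leaves the new Optional parameter at its None default would raise a TypeError (BaseNode.publish_text_chunk, further down, always passes a dict). A tolerant merge would look like this sketch (hypothetical helper, not the committed code):

    # Hypothetical guard; the committed callback unpacks metadata directly.
    def build_chunk_metadata(node_id: str, metadata: dict | None) -> dict:
        return {"node_id": node_id, **(metadata or {})}

    assert build_chunk_metadata("llm", None) == {"node_id": "llm"}
    assert build_chunk_metadata("llm", {"node_type": "llm"}) == {
        "node_id": "llm", "node_type": "llm"
    }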
@@ -3,12 +3,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import (
     AppQueueEvent,
     MessageQueueMessage,
+    QueueAdvancedChatMessageEndEvent,
     QueueErrorEvent,
     QueueMessage,
     QueueMessageEndEvent,
     QueueStopEvent,
     QueueWorkflowFailedEvent,
     QueueWorkflowSucceededEvent,
 )
@@ -54,8 +53,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
         if isinstance(event, QueueStopEvent
                       | QueueErrorEvent
                       | QueueMessageEndEvent
-                      | QueueWorkflowSucceededEvent
-                      | QueueWorkflowFailedEvent):
+                      | QueueAdvancedChatMessageEndEvent):
             self.stop_listen()
 
         if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
@@ -112,7 +112,7 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
             PublishFrom.APPLICATION_MANAGER
         )
 
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
@@ -17,6 +17,7 @@ class QueueEvent(Enum):
     AGENT_MESSAGE = "agent_message"
     MESSAGE_REPLACE = "message_replace"
     MESSAGE_END = "message_end"
+    ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end"
     WORKFLOW_STARTED = "workflow_started"
     WORKFLOW_SUCCEEDED = "workflow_succeeded"
     WORKFLOW_FAILED = "workflow_failed"
@@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     event = QueueEvent.TEXT_CHUNK
     text: str
+    metadata: Optional[dict] = None
 
 
 class QueueAgentMessageEvent(AppQueueEvent):
@@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent):
     QueueMessageEndEvent entity
     """
     event = QueueEvent.MESSAGE_END
-    llm_result: LLMResult
+    llm_result: Optional[LLMResult] = None
+
+
+class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
+    """
+    QueueAdvancedChatMessageEndEvent entity
+    """
+    event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
 
 
 class QueueWorkflowStartedEvent(AppQueueEvent):
@@ -64,7 +64,7 @@ class BaseWorkflowCallback(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
@@ -4,7 +4,12 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
-from core.workflow.nodes.answer.entities import AnswerNodeData
+from core.workflow.nodes.answer.entities import (
+    AnswerNodeData,
+    GenerateRouteChunk,
+    TextGenerateRouteChunk,
+    VarGenerateRouteChunk,
+)
 from core.workflow.nodes.base_node import BaseNode
 from models.workflow import WorkflowNodeExecutionStatus
@@ -22,6 +27,40 @@ class AnswerNode(BaseNode):
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
 
+        # generate routes
+        generate_routes = self.extract_generate_route_from_node_data(node_data)
+
+        answer = []
+        for part in generate_routes:
+            if part.type == "var":
+                part = cast(VarGenerateRouteChunk, part)
+                value_selector = part.value_selector
+                value = variable_pool.get_variable_value(
+                    variable_selector=value_selector,
+                    target_value_type=ValueType.STRING
+                )
+
+                answer_part = {
+                    "type": "text",
+                    "text": value
+                }
+                # TODO File
+            else:
+                part = cast(TextGenerateRouteChunk, part)
+                answer_part = {
+                    "type": "text",
+                    "text": part.text
+                }
+
+            if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
+                answer[-1]["text"] += answer_part["text"]
+            else:
+                answer.append(answer_part)
+
+        if len(answer) == 1 and answer[0]["type"] == "text":
+            answer = answer[0]["text"]
+
+        # re-fetch variable values
         variable_values = {}
         for variable_selector in node_data.variables:
             value = variable_pool.get_variable_value(
@@ -31,7 +70,39 @@ class AnswerNode(BaseNode):
 
             variable_values[variable_selector.variable] = value
 
-        variable_keys = list(variable_values.keys())
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs=variable_values,
+            outputs={
+                "answer": answer
+            }
+        )
+
+    @classmethod
+    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route selectors
+        :param config: node config
+        :return:
+        """
+        node_data = cls._node_data_cls(**config.get("data", {}))
+        node_data = cast(cls._node_data_cls, node_data)
+
+        return cls.extract_generate_route_from_node_data(node_data)
+
+    @classmethod
+    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route from node data
+        :param node_data: node data object
+        :return:
+        """
+        value_selector_mapping = {
+            variable_selector.variable: variable_selector.value_selector
+            for variable_selector in node_data.variables
+        }
+
+        variable_keys = list(value_selector_mapping.keys())
+
         # format answer template
         template_parser = PromptTemplateParser(node_data.answer)
@@ -44,46 +115,24 @@ class AnswerNode(BaseNode):
         for var in variable_keys:
             template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
 
-        split_template = [
-            {
-                "type": "var" if self._is_variable(part, variable_keys) else "text",
-                "value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
-            }
-            for part in template.split('Ω') if part
-        ]
-
-        answer = []
-        for part in split_template:
-            if part["type"] == "var":
-                value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
-                answer_part = {
-                    "type": "text",
-                    "text": value
-                }
-                # TODO File
-            else:
-                answer_part = {
-                    "type": "text",
-                    "text": part["value"]
-                }
-
-            if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
-                answer[-1]["text"] += answer_part["text"]
-            else:
-                answer.append(answer_part)
-
-        if len(answer) == 1 and answer[0]["type"] == "text":
-            answer = answer[0]["text"]
-
-        return NodeRunResult(
-            status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            inputs=variable_values,
-            outputs={
-                "answer": answer
-            }
-        )
-
-    def _is_variable(self, part, variable_keys):
+        generate_routes = []
+        for part in template.split('Ω'):
+            if part:
+                if cls._is_variable(part, variable_keys):
+                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
+                    value_selector = value_selector_mapping[var_key]
+                    generate_routes.append(VarGenerateRouteChunk(
+                        value_selector=value_selector
+                    ))
+                else:
+                    generate_routes.append(TextGenerateRouteChunk(
+                        text=part
+                    ))
+
+        return generate_routes
+
+    @classmethod
+    def _is_variable(cls, part, variable_keys):
         cleaned_part = part.replace('{{', '').replace('}}', '')
         return part.startswith('{{') and cleaned_part in variable_keys
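
The Ω sentinel trick above is easiest to see on a concrete template: each known {{variable}} is wrapped in Ω markers, the string is split on Ω, and each fragment becomes either a var chunk (carrying its value selector) or a text chunk. A standalone sketch with dicts in place of the chunk models (names assumed):

    # Standalone sketch of the Ω-sentinel split used by
    # extract_generate_route_from_node_data.
    def extract_generate_route(answer_template: str,
                               value_selectors: dict[str, list[str]]) -> list[dict]:
        template = answer_template
        for var in value_selectors:
            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

        routes = []
        for part in template.split('Ω'):
            if not part:
                continue
            key = part.replace('{{', '').replace('}}', '')
            if part.startswith('{{') and key in value_selectors:
                routes.append({'type': 'var', 'value_selector': value_selectors[key]})
            else:
                routes.append({'type': 'text', 'text': part})
        return routes

    routes = extract_generate_route('Answer: {{text}}!', {'text': ['llm', 'text']})
    assert routes == [
        {'type': 'text', 'text': 'Answer: '},
        {'type': 'var', 'value_selector': ['llm', 'text']},
        {'type': 'text', 'text': '!'},
    ]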
@@ -1,3 +1,6 @@
+from pydantic import BaseModel
+
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
 
@@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
     """
     variables: list[VariableSelector] = []
     answer: str
+
+
+class GenerateRouteChunk(BaseModel):
+    """
+    Generate Route Chunk.
+    """
+    type: str
+
+
+class VarGenerateRouteChunk(GenerateRouteChunk):
+    """
+    Var Generate Route Chunk.
+    """
+    type: str = "var"
+    value_selector: list[str]
+
+
+class TextGenerateRouteChunk(GenerateRouteChunk):
+    """
+    Text Generate Route Chunk.
+    """
+    type: str = "text"
+    text: str
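
These are plain pydantic models discriminated by the `type` string, so a parsed answer template is simply an ordered list of chunks; a short usage sketch:

    from pydantic import BaseModel

    class GenerateRouteChunk(BaseModel):
        type: str

    class VarGenerateRouteChunk(GenerateRouteChunk):
        type: str = "var"
        value_selector: list[str]

    class TextGenerateRouteChunk(GenerateRouteChunk):
        type: str = "text"
        text: str

    route = [
        TextGenerateRouteChunk(text="Answer: "),
        VarGenerateRouteChunk(value_selector=["llm", "text"]),
    ]
    assert [chunk.type for chunk in route] == ["text", "var"]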
@@ -86,17 +86,22 @@ class BaseNode(ABC):
         self.node_run_result = result
         return result
 
-    def publish_text_chunk(self, text: str) -> None:
+    def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
         """
         Publish text chunk
         :param text: chunk text
+        :param value_selector: value selector
         :return:
         """
         if self.callbacks:
             for callback in self.callbacks:
                 callback.on_node_text_chunk(
                     node_id=self.node_id,
-                    text=text
+                    text=text,
+                    metadata={
+                        "node_type": self.node_type,
+                        "value_selector": value_selector
+                    }
                 )
 
     @classmethod
@@ -169,7 +169,7 @@ class LLMNode(BaseNode):
                 text = result.delta.message.content
                 full_text += text
 
-                self.publish_text_chunk(text=text)
+                self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
 
                 if not model:
                     model = result.model
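
End to end, the handshake is: the LLM node stamps each streamed delta with value_selector=[node_id, 'text'], BaseNode.publish_text_chunk folds that into chunk metadata, and the pipeline's _is_stream_out_support admits the chunk only when the current var route chunk carries the same selector. A compressed standalone sketch (all names assumed):

    from typing import Optional

    def make_chunk_event(node_id: str, node_type: str, text: str,
                         value_selector: Optional[list] = None) -> dict:
        # shape of the metadata BaseNode.publish_text_chunk hands to callbacks
        return {
            'text': text,
            'metadata': {
                'node_id': node_id,
                'node_type': node_type,
                'value_selector': value_selector,
            },
        }

    event = make_chunk_event('llm_1', 'llm', 'Hel', value_selector=['llm_1', 'text'])
    current_var_chunk_selector = ['llm_1', 'text']
    assert event['metadata']['value_selector'] == current_var_chunk_selector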