From dad1a967eeb1bf7625d20980bb6c4467cd297c46 Mon Sep 17 00:00:00 2001
From: takatost
Date: Sat, 20 Jul 2024 00:49:46 +0800
Subject: [PATCH] finished answer stream output

---
 .../advanced_chat/generate_task_pipeline.py   | 302 +-----------------
 .../workflow/graph_engine/entities/event.py   |   3 +-
 .../workflow/graph_engine/entities/graph.py   |  31 +-
 .../entities/graph_runtime_state.py           |  10 +-
 .../workflow/graph_engine/graph_engine.py     |  13 +-
 api/core/workflow/nodes/answer/answer_node.py |   4 +-
 .../answer/answer_stream_generate_router.py   | 203 ++++++++++++
 .../answer/answer_stream_output_manager.py    | 160 ----------
 .../nodes/answer/answer_stream_processor.py   | 286 +++++++++++++++++
 api/core/workflow/nodes/answer/entities.py    |  20 +-
 .../core/workflow/nodes/answer/__init__.py    |   0
 .../nodes/{ => answer}/test_answer.py         |  65 +++-
 .../test_answer_stream_generate_router.py     | 125 ++++++++
 .../answer/test_answer_stream_processor.py    | 206 ++++++++++++
 .../core/workflow/nodes/test_if_else.py       |  83 +++--
 15 files changed, 989 insertions(+), 522 deletions(-)
 create mode 100644 api/core/workflow/nodes/answer/answer_stream_generate_router.py
 delete mode 100644 api/core/workflow/nodes/answer/answer_stream_output_manager.py
 create mode 100644 api/core/workflow/nodes/answer/answer_stream_processor.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/__init__.py
 rename api/tests/unit_tests/core/workflow/nodes/{ => answer}/test_answer.py (54%)
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py

diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 4b089f033f..b332ac7af8 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -2,7 +2,7 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
@@ -33,7 +33,6 @@ from core.app.entities.task_entities import (
     AdvancedChatTaskState,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
-    ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
@@ -43,20 +42,16 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
-from core.file.file_obj import FileVar
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.entities.node_entities import NodeType, SystemVariable
-from core.workflow.nodes.answer.answer_node import AnswerNode
-from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
-    WorkflowNodeExecution,
     WorkflowRunStatus,
 )

@@ -430,102 +425,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             **extras
         )

-    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
-        """
-        Get stream generate routes.
-        :return:
-        """
-        # find all answer nodes
-        graph = self._workflow.graph_dict
-        answer_node_configs = [
-            node for node in graph['nodes']
-            if node.get('data', {}).get('type') == NodeType.ANSWER.value
-        ]
-
-        # parse stream output node value selectors of answer nodes
-        stream_generate_routes = {}
-        for node_config in answer_node_configs:
-            # get generate route for stream output
-            answer_node_id = node_config['id']
-            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
-            start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
-            if not start_node_ids:
-                continue
-
-            for start_node_id in start_node_ids:
-                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
-                    answer_node_id=answer_node_id,
-                    generate_route=generate_route
-                )
-
-        return stream_generate_routes
-
-    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-            -> list[str]:
-        """
-        Get answer start at node id.
-        :param graph: graph
-        :param target_node_id: target node ID
-        :return:
-        """
-        nodes = graph.get('nodes')
-        edges = graph.get('edges')
-
-        # fetch all ingoing edges from source node
-        ingoing_edges = []
-        for edge in edges:
-            if edge.get('target') == target_node_id:
-                ingoing_edges.append(edge)
-
-        if not ingoing_edges:
-            # check if it's the first node in the iteration
-            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
-            if not target_node:
-                return []
-
-            node_iteration_id = target_node.get('data', {}).get('iteration_id')
-            # get iteration start node id
-            for node in nodes:
-                if node.get('id') == node_iteration_id:
-                    if node.get('data', {}).get('start_node_id') == target_node_id:
-                        return [target_node_id]
-
-            return []
-
-        start_node_ids = []
-        for ingoing_edge in ingoing_edges:
-            source_node_id = ingoing_edge.get('source')
-            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
-            if not source_node:
-                continue
-
-            node_type = source_node.get('data', {}).get('type')
-            node_iteration_id = source_node.get('data', {}).get('iteration_id')
-            iteration_start_node_id = None
-            if node_iteration_id:
-                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
-                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
-
-            if node_type in [
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
-                NodeType.LOOP.value
-            ]:
-                start_node_id = target_node_id
-                start_node_ids.append(start_node_id)
-            elif node_type == NodeType.START.value or \
-                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
-                start_node_id = source_node_id
-                start_node_ids.append(start_node_id)
-            else:
-                sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
-                if sub_start_node_ids:
-                    start_node_ids.extend(sub_start_node_ids)
-
-        return start_node_ids
-
     def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
         """
         Get iteration nested relations.
@@ -546,205 +445,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             ]
             for iteration_id in iteration_ids
         }

-    def _generate_stream_outputs_when_node_started(self) -> Generator:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if self._task_state.current_stream_generate_state:
-            route_chunks = self._task_state.current_stream_generate_state.generate_route[
-                self._task_state.current_stream_generate_state.current_route_position:
-            ]
-
-            for route_chunk in route_chunks:
-                if route_chunk.type == 'text':
-                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-
-                    # handle output moderation chunk
-                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
-                    if should_direct_answer:
-                        continue
-
-                    self._task_state.answer += route_chunk.text
-                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
-                else:
-                    break
-
-                self._task_state.current_stream_generate_state.current_route_position += 1
-
-            # all route chunks are generated
-            if self._task_state.current_stream_generate_state.current_route_position == len(
-                    self._task_state.current_stream_generate_state.generate_route
-            ):
-                self._task_state.current_stream_generate_state = None
-
-    def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
-        """
-        Generate stream outputs.
-        :return:
-        """
-        if not self._task_state.current_stream_generate_state:
-            return
-
-        route_chunks = self._task_state.current_stream_generate_state.generate_route[
-            self._task_state.current_stream_generate_state.current_route_position:]
-
-        for route_chunk in route_chunks:
-            if route_chunk.type == 'text':
-                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                self._task_state.answer += route_chunk.text
-                yield self._message_to_stream_response(route_chunk.text, self._message.id)
-            else:
-                value = None
-                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-                value_selector = route_chunk.value_selector
-                if not value_selector:
-                    self._task_state.current_stream_generate_state.current_route_position += 1
-                    continue
-
-                route_chunk_node_id = value_selector[0]
-
-                if route_chunk_node_id == 'sys':
-                    # system variable
-                    value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
-                elif route_chunk_node_id in self._iteration_nested_relations:
-                    # it's a iteration variable
-                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
-                        continue
-                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
-                    iterator = iteration_state.inputs
-                    if not iterator:
-                        continue
-                    iterator_selector = iterator.get('iterator_selector', [])
-                    if value_selector[1] == 'index':
-                        value = iteration_state.current_index
-                    elif value_selector[1] == 'item':
-                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
-                            iterator_selector
-                        ) else None
-                else:
-                    # check chunk node id is before current node id or equal to current node id
-                    if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
-                        break
-
-                    latest_node_execution_info = self._task_state.latest_node_execution_info
-
-                    # get route chunk node execution info
-                    route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
-                    if (route_chunk_node_execution_info.node_type == NodeType.LLM
-                            and latest_node_execution_info.node_type == NodeType.LLM):
-                        # only LLM support chunk stream output
-                        self._task_state.current_stream_generate_state.current_route_position += 1
-                        continue
-
-                    # get route chunk node execution
-                    route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
-                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
-                    ).first()
-
-                    outputs = route_chunk_node_execution.outputs_dict
-
-                    # get value from outputs
-                    value = None
-                    for key in value_selector[1:]:
-                        if not value:
-                            value = outputs.get(key) if outputs else None
-                        else:
-                            value = value.get(key)
-
-                if value is not None:
-                    text = ''
-                    if isinstance(value, str | int | float):
-                        text = str(value)
-                    elif isinstance(value, FileVar):
-                        # convert file to markdown
-                        text = value.to_markdown()
-                    elif isinstance(value, dict):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        if file_vars:
-                            file_var = file_vars[0]
-                            try:
-                                file_var_obj = FileVar(**file_var)
-
-                                # convert file to markdown
-                                text = file_var_obj.to_markdown()
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-
-                        if not text:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-                    elif isinstance(value, list):
-                        # handle files
-                        file_vars = self._fetch_files_from_variable_value(value)
-                        for file_var in file_vars:
-                            try:
-                                file_var_obj = FileVar(**file_var)
-                            except Exception as e:
-                                logger.error(f'Error creating file var: {e}')
-                                continue
-
-                            # convert file to markdown
-                            text = file_var_obj.to_markdown() + ' '
-
-                        text = text.strip()
-
-                        if not text and value:
-                            # other types
-                            text = json.dumps(value, ensure_ascii=False)
-
-                    if text:
-                        self._task_state.answer += text
-                        yield self._message_to_stream_response(text, self._message.id)
-
-            self._task_state.current_stream_generate_state.current_route_position += 1
-
-        # all route chunks are generated
-        if self._task_state.current_stream_generate_state.current_route_position == len(
-                self._task_state.current_stream_generate_state.generate_route
-        ):
-            self._task_state.current_stream_generate_state = None
-
-    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
-        """
-        Is stream out support
-        :param event: queue text chunk event
-        :return:
-        """
-        if not event.metadata:
-            return True
-
-        if 'node_id' not in event.metadata:
-            return True
-
-        node_type = event.metadata.get('node_type')
-        stream_output_value_selector = event.metadata.get('value_selector')
-        if not stream_output_value_selector:
-            return False
-
-        if not self._task_state.current_stream_generate_state:
-            return False
-
-        route_chunk = self._task_state.current_stream_generate_state.generate_route[
-            self._task_state.current_stream_generate_state.current_route_position]
-
-        if route_chunk.type != 'var':
-            return False
-
-        if node_type != NodeType.LLM:
-            # only LLM support chunk stream output
-            return False
-
-        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-        value_selector = route_chunk.value_selector
-
-        # check chunk node id is before current node id or equal to current node id
-        if value_selector != stream_output_value_selector:
-            return False
-
-        return True
-
     def _handle_output_moderation_chunk(self, text: str) -> bool:
         """
         Handle output moderation chunk.
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 5147982e1b..071ad164f8 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -50,7 +50,8 @@ class NodeRunStartedEvent(BaseNodeEvent):

 class NodeRunStreamChunkEvent(BaseNodeEvent):
     chunk_content: str = Field(..., description="chunk content")
-    from_variable_selector: list[str] = Field(..., description="from variable selector")
+    from_variable_selector: Optional[list[str]] = None
+    """from variable selector"""


 class NodeRunRetrieverResourceEvent(BaseNodeEvent):

diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 2d52eab5f4..e12f7cec47 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -5,21 +5,24 @@ from pydantic import BaseModel, Field

 from core.workflow.entities.node_entities import NodeType
 from core.workflow.graph_engine.entities.run_condition import RunCondition
-from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
 from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute


 class GraphEdge(BaseModel):
     source_node_id: str = Field(..., description="source node id")
     target_node_id: str = Field(..., description="target node id")
-    run_condition: Optional[RunCondition] = Field(None, description="run condition")
+    run_condition: Optional[RunCondition] = None
+    """run condition"""


 class GraphParallel(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
     start_from_node_id: str = Field(..., description="start from node id")
-    parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
-    end_to_node_id: Optional[str] = Field(None, description="end to node id")
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id"""
+    end_to_node_id: Optional[str] = None
+    """end to node id"""


 class Graph(BaseModel):
@@ -33,6 +36,10 @@ class Graph(BaseModel):
         default_factory=dict,
         description="graph edge mapping (source node id: edges)"
     )
+    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
+        default_factory=dict,
+        description="reverse graph edge mapping (target node id: edges)"
+    )
     parallel_mapping: dict[str, GraphParallel] = Field(
         default_factory=dict,
         description="graph parallel mapping (parallel id: parallel)"
@@ -41,8 +48,8 @@ class Graph(BaseModel):
         default_factory=dict,
         description="graph node parallel mapping (node id: parallel id)"
     )
-    answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
-        default_factory=dict,
+    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
+        ...,
         description="answer stream generate routes"
     )

@@ -66,6 +73,7 @@ class Graph(BaseModel):

         # reorganize edges mapping
         edge_mapping: dict[str, list[GraphEdge]] = {}
+        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
         target_edge_ids = set()
         for edge_config in edge_configs:
             source_node_id = edge_config.get('source')
@@ -79,6 +87,9 @@ class Graph(BaseModel):
             if not target_node_id:
                 continue

+            if target_node_id not in reverse_edge_mapping:
+                reverse_edge_mapping[target_node_id] = []
+
             target_edge_ids.add(target_node_id)

             # parse run condition
@@ -91,11 +102,12 @@ class Graph(BaseModel):

             graph_edge = GraphEdge(
                 source_node_id=source_node_id,
-                target_node_id=edge_config.get('target'),
+                target_node_id=target_node_id,
                 run_condition=run_condition
             )

             edge_mapping[source_node_id].append(graph_edge)
+            reverse_edge_mapping[target_node_id].append(graph_edge)

         # node configs
         node_configs = graph_config.get('nodes')
@@ -149,9 +161,9 @@ class Graph(BaseModel):
         )

         # init answer stream generate routes
-        answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
+        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
             node_id_config_mapping=node_id_config_mapping,
-            edge_mapping=edge_mapping
+            reverse_edge_mapping=reverse_edge_mapping
         )

         # init graph
@@ -160,6 +172,7 @@ class Graph(BaseModel):
             node_ids=node_ids,
             node_id_config_mapping=node_id_config_mapping,
             edge_mapping=edge_mapping,
+            reverse_edge_mapping=reverse_edge_mapping,
             parallel_mapping=parallel_mapping,
             node_parallel_mapping=node_parallel_mapping,
             answer_stream_generate_routes=answer_stream_generate_routes

diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
index 4d17d277e3..e6ae6df559 100644
--- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py
+++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
@@ -8,7 +8,11 @@ class GraphRuntimeState(BaseModel):
     variable_pool: VariablePool = Field(..., description="variable pool")

     start_at: float = Field(..., description="start time")
-    total_tokens: int = Field(0, description="total tokens")
-    node_run_steps: int = Field(0, description="node run steps")
+    total_tokens: int = 0
+    """total tokens"""

-    node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
+    node_run_steps: int = 0
+    """node run steps"""
+
+    node_run_state: RuntimeRouteState = RuntimeRouteState()
+    """node run state"""

diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 6289677556..80e2cf6899 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -28,6 +28,7 @@ from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
 from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.node_mapping import node_classes
 from extensions.ext_database import db
@@ -81,6 +82,10 @@ class GraphEngine:
         try:
             # run graph
             generator = self._run(start_node_id=self.graph.root_node_id)
+            if self.init_params.workflow_type == WorkflowType.CHAT:
+                answer_stream_processor = AnswerStreamProcessor(graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool)
+                generator = answer_stream_processor.process(generator)
+
             for item in generator:
                 yield item
                 if isinstance(item, NodeRunFailedEvent):
@@ -314,8 +319,6 @@ class GraphEngine:
             db.session.close()

-            # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
-
             self.graph_runtime_state.node_run_steps += 1

             try:
@@ -335,7 +338,7 @@ class GraphEngine:
                     if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                         # plus state total_tokens
                         self.graph_runtime_state.total_tokens += int(
-                            run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
+                            run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                         )

                     # append node output variables to variable pool
@@ -397,7 +400,7 @@ class GraphEngine:
                 self.graph_runtime_state.variable_pool.append_variable(
                     node_id=node_id,
                     variable_key_list=variable_key_list,
-                    value=variable_value
+                    value=variable_value  # type: ignore[arg-type]
                 )

                 # if variable_value is a dict, then recursively append variables

diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py
index a667c9ab73..b6c67d585c 100644
--- a/api/core/workflow/nodes/answer/answer_node.py
+++ b/api/core/workflow/nodes/answer/answer_node.py
@@ -4,7 +4,7 @@ from typing import cast
 from core.file.file_obj import FileVar
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
-from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
 from core.workflow.nodes.answer.entities import (
     AnswerNodeData,
     GenerateRouteChunk,
@@ -29,7 +29,7 @@ class AnswerNode(BaseNode):
         node_data = cast(AnswerNodeData, node_data)

         # generate routes
-        generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
+        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)

         answer = ''
         for part in generate_routes:

diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
new file mode 100644
index 0000000000..6e7dcb7ede
--- /dev/null
+++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
@@ -0,0 +1,203 @@
+
+from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.answer.entities import (
+    AnswerNodeData,
+    AnswerStreamGenerateRoute,
+    GenerateRouteChunk,
+    TextGenerateRouteChunk,
+    VarGenerateRouteChunk,
+)
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
+
+
+class AnswerStreamGeneratorRouter:
+
+    @classmethod
+    def init(cls,
+             node_id_config_mapping: dict[str, dict],
+             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
+             ) -> AnswerStreamGenerateRoute:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # parse stream output node value selectors of answer nodes
+        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
+        for answer_node_id, node_config in node_id_config_mapping.items():
+            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
+                continue
+
+            # get generate route for stream output
+            generate_route = cls._extract_generate_route_selectors(node_config)
+            answer_generate_route[answer_node_id] = generate_route
+
+        # fetch answer dependencies
+        answer_node_ids = list(answer_generate_route.keys())
+        answer_dependencies = cls._fetch_answers_dependencies(
+            answer_node_ids=answer_node_ids,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_id_config_mapping=node_id_config_mapping
+        )
+
+        return AnswerStreamGenerateRoute(
+            answer_generate_route=answer_generate_route,
+            answer_dependencies=answer_dependencies
+        )
+
+    @classmethod
+    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route from node data
+        :param node_data: node data object
+        :return:
+        """
+        variable_template_parser = VariableTemplateParser(template=node_data.answer)
+        variable_selectors = variable_template_parser.extract_variable_selectors()
+
+        value_selector_mapping = {
+            variable_selector.variable: variable_selector.value_selector
+            for variable_selector in variable_selectors
+        }
+
+        variable_keys = list(value_selector_mapping.keys())
+
+        # format answer template
+        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
+        template_variable_keys = template_parser.variable_keys
+
+        # Take the intersection of variable_keys and template_variable_keys
+        variable_keys = list(set(variable_keys) & set(template_variable_keys))
+
+        template = node_data.answer
+        for var in variable_keys:
+            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
+
+        generate_routes: list[GenerateRouteChunk] = []
+        for part in template.split('Ω'):
+            if part:
+                if cls._is_variable(part, variable_keys):
+                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
+                    value_selector = value_selector_mapping[var_key]
+                    generate_routes.append(VarGenerateRouteChunk(
+                        value_selector=value_selector
+                    ))
+                else:
+                    generate_routes.append(TextGenerateRouteChunk(
+                        text=part
+                    ))
+
+        return generate_routes
+
+    @classmethod
+    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route selectors
+        :param config: node config
+        :return:
+        """
+        node_data = AnswerNodeData(**config.get("data", {}))
+        return cls.extract_generate_route_from_node_data(node_data)
+
+    @classmethod
+    def _is_variable(cls, part, variable_keys):
+        cleaned_part = part.replace('{{', '').replace('}}', '')
+        return part.startswith('{{') and cleaned_part in variable_keys
+
+    @classmethod
+    def _fetch_answers_dependencies(cls,
+                                    answer_node_ids: list[str],
+                                    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                    node_id_config_mapping: dict[str, dict]
+                                    ) -> dict[str, list[str]]:
+        """
+        Fetch answer dependencies
+        :param answer_node_ids: answer node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param node_id_config_mapping: node id config mapping
+        :return:
+        """
+        answer_dependencies: dict[str, list[str]] = {}
+        for answer_node_id in answer_node_ids:
+            if answer_dependencies.get(answer_node_id) is None:
+                answer_dependencies[answer_node_id] = []
+
+            cls._recursive_fetch_answer_dependencies(
+                current_node_id=answer_node_id,
+                answer_node_id=answer_node_id,
+                node_id_config_mapping=node_id_config_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                answer_dependencies=answer_dependencies
+            )
+
+        return answer_dependencies
+
+    @classmethod
+    def _recursive_fetch_answer_dependencies(cls,
+                                             current_node_id: str,
+                                             answer_node_id: str,
+                                             node_id_config_mapping: dict[str, dict],
+                                             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                             answer_dependencies: dict[str, list[str]]
+                                             ) -> None:
+        """
+        Recursive fetch answer dependencies
+        :param current_node_id: current node id
+        :param answer_node_id: answer node id
+        :param node_id_config_mapping: node id config mapping
+        :param reverse_edge_mapping: reverse edge mapping
+        :param answer_dependencies: answer dependencies
+        :return:
+        """
+        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
+        for edge in reverse_edges:
+            source_node_id = edge.source_node_id
+            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
+            # compare against the enum values, since node types in the config are plain strings
+            if source_node_type in (
+                NodeType.ANSWER.value,
+                NodeType.IF_ELSE.value,
+                NodeType.QUESTION_CLASSIFIER.value,
+            ):
+                answer_dependencies[answer_node_id].append(source_node_id)
+            else:
+                cls._recursive_fetch_answer_dependencies(
+                    current_node_id=source_node_id,
+                    answer_node_id=answer_node_id,
+                    node_id_config_mapping=node_id_config_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    answer_dependencies=answer_dependencies
+                )
+
+    @classmethod
+    def _fetch_answer_dependencies(cls,
+                                   current_node_id: str,
+                                   answer_node_id: str,
+                                   answer_node_ids: list[str],
+                                   reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                   answer_dependencies: dict[str, list[str]]
+                                   ) -> None:
+        """
+        Fetch answer dependencies
+        :param current_node_id: current node id
+        :param answer_node_id: answer node id
+        :param answer_node_ids: answer node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param answer_dependencies: answer dependencies
+        :return:
+        """
+        for edge in reverse_edge_mapping.get(current_node_id, []):
+            source_node_id = edge.source_node_id
+            if source_node_id == answer_node_id:
+                continue
+
+            if source_node_id in answer_node_ids:
+                # is answer node
+                answer_dependencies[answer_node_id].append(source_node_id)
+            else:
+                cls._fetch_answer_dependencies(
+                    current_node_id=source_node_id,
+                    answer_node_id=answer_node_id,
+                    answer_node_ids=answer_node_ids,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    answer_dependencies=answer_dependencies
+                )

diff --git a/api/core/workflow/nodes/answer/answer_stream_output_manager.py b/api/core/workflow/nodes/answer/answer_stream_output_manager.py
deleted file mode 100644
index ff2d955cd2..0000000000
--- a/api/core/workflow/nodes/answer/answer_stream_output_manager.py
+++ /dev/null
@@ -1,160 +0,0 @@
-from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.workflow.entities.node_entities import NodeType
-from core.workflow.nodes.answer.entities import (
-    AnswerNodeData,
-    AnswerStreamGenerateRoute,
-    GenerateRouteChunk,
-    TextGenerateRouteChunk,
-    VarGenerateRouteChunk,
-)
-from core.workflow.utils.variable_template_parser import VariableTemplateParser
-
-
-class AnswerStreamOutputManager:
-
-    @classmethod
-    def init_stream_generate_routes(cls,
-                                    node_id_config_mapping: dict[str, dict],
-                                    edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
-                                    ) -> dict[str, AnswerStreamGenerateRoute]:
-        """
-        Get stream generate routes.
-        :return:
-        """
-        # parse stream output node value selectors of answer nodes
-        stream_generate_routes = {}
-        for node_id, node_config in node_id_config_mapping.items():
-            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
-                continue
-
-            # get generate route for stream output
-            generate_route = cls._extract_generate_route_selectors(node_config)
-            streaming_node_ids = cls._get_streaming_node_ids(
-                target_node_id=node_id,
-                node_id_config_mapping=node_id_config_mapping,
-                edge_mapping=edge_mapping
-            )
-
-            if not streaming_node_ids:
-                continue
-
-            for streaming_node_id in streaming_node_ids:
-                stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
-                    answer_node_id=node_id,
-                    generate_route=generate_route
-                )
-
-        return stream_generate_routes
-
-    @classmethod
-    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route from node data
-        :param node_data: node data object
-        :return:
-        """
-        variable_template_parser = VariableTemplateParser(template=node_data.answer)
-        variable_selectors = variable_template_parser.extract_variable_selectors()
-
-        value_selector_mapping = {
-            variable_selector.variable: variable_selector.value_selector
-            for variable_selector in variable_selectors
-        }
-
-        variable_keys = list(value_selector_mapping.keys())
-
-        # format answer template
-        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
-        template_variable_keys = template_parser.variable_keys
-
-        # Take the intersection of variable_keys and template_variable_keys
-        variable_keys = list(set(variable_keys) & set(template_variable_keys))
-
-        template = node_data.answer
-        for var in variable_keys:
-            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
-
-        generate_routes: list[GenerateRouteChunk] = []
-        for part in template.split('Ω'):
-            if part:
-                if cls._is_variable(part, variable_keys):
-                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
-                    value_selector = value_selector_mapping[var_key]
-                    generate_routes.append(VarGenerateRouteChunk(
-                        value_selector=value_selector
-                    ))
-                else:
-                    generate_routes.append(TextGenerateRouteChunk(
-                        text=part
-                    ))
-
-        return generate_routes
-
-    @classmethod
-    def _get_streaming_node_ids(cls,
-                                target_node_id: str,
-                                node_id_config_mapping: dict[str, dict],
-                                edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]:  # type: ignore[name-defined]
-        """
-        Get answer stream node IDs.
-        :param target_node_id: target node ID
-        :return:
-        """
-        # fetch all ingoing edges from source node
-        ingoing_graph_edges = []
-        for graph_edges in edge_mapping.values():
-            for graph_edge in graph_edges:
-                if graph_edge.target_node_id == target_node_id:
-                    ingoing_graph_edges.append(graph_edge)
-
-        if not ingoing_graph_edges:
-            return []
-
-        streaming_node_ids = []
-        for ingoing_graph_edge in ingoing_graph_edges:
-            source_node_id = ingoing_graph_edge.source_node_id
-            source_node = node_id_config_mapping.get(source_node_id)
-            if not source_node:
-                continue
-
-            node_type = source_node.get('data', {}).get('type')
-
-            if node_type in [
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
-                NodeType.LOOP.value
-            ]:
-                # Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
-                streaming_node_ids.append(target_node_id)
-            elif node_type == NodeType.START.value:
-                # Current node is START node, can be stream node.
-                streaming_node_ids.append(source_node_id)
-            else:
-                # Find the stream node forward.
-                sub_streaming_node_ids = cls._get_streaming_node_ids(
-                    target_node_id=source_node_id,
-                    node_id_config_mapping=node_id_config_mapping,
-                    edge_mapping=edge_mapping
-                )
-
-                if sub_streaming_node_ids:
-                    streaming_node_ids.extend(sub_streaming_node_ids)
-
-        return streaming_node_ids
-
-    @classmethod
-    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
-        """
-        Extract generate route selectors
-        :param config: node config
-        :return:
-        """
-        node_data = AnswerNodeData(**config.get("data", {}))
-
-        return cls.extract_generate_route_from_node_data(node_data)
-
-    @classmethod
-    def _is_variable(cls, part, variable_keys):
-        cleaned_part = part.replace('{{', '').replace('}}', '')
-        return part.startswith('{{') and cleaned_part in variable_keys

diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
new file mode 100644
index 0000000000..a28a7786f7
--- /dev/null
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -0,0 +1,286 @@
+import json
+import logging
+from collections.abc import Generator
+from typing import Optional, cast
+
+from core.file.file_obj import FileVar
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
+
+logger = logging.getLogger(__name__)
+
+
+class AnswerStreamProcessor:
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        self.graph = graph
+        self.variable_pool = variable_pool
+        self.generate_routes = graph.answer_stream_generate_routes
+        self.route_position: dict[str, int] = {}
+        for answer_node_id in self.generate_routes.answer_generate_route:
+            self.route_position[answer_node_id] = 0
+        self.rest_node_ids = graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStreamChunkEvent):
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ] = stream_out_answer_node_ids
+
+                # yield the chunk once per answer node that is currently streaming this output
+                for _ in stream_out_answer_node_ids:
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream event finished
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[answer_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for answer_node_id in self.generate_routes.answer_generate_route:
+            self.route_position[answer_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+
+    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+        finished_node_id = event.route_node_state.node_id
+
+        if finished_node_id not in self.rest_node_ids:
+            return
+
+        # remove finished node id
+        self.rest_node_ids.remove(finished_node_id)
+
+        run_result = event.route_node_state.node_run_result
+        if not run_result:
+            return
+
+        if run_result.edge_source_handle:
+            reachable_node_ids = []
+            unreachable_first_node_ids = []
+            for edge in self.graph.edge_mapping[finished_node_id]:
+                if (edge.run_condition
+                        and edge.run_condition.branch_identify
+                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
+                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+                    continue
+                else:
+                    unreachable_first_node_ids.append(edge.target_node_id)
+
+            for node_id in unreachable_first_node_ids:
+                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
+
+    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
+        node_ids = []
+        # use .get() because leaf nodes have no outgoing edges in edge_mapping
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            node_ids.append(edge.target_node_id)
+            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+        return node_ids
+
+    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
+        """
+        remove target node ids until merge
+        """
+        if node_id in self.rest_node_ids:
+            self.rest_node_ids.remove(node_id)
+
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            if edge.target_node_id in reachable_node_ids:
+                continue
+
+            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
+
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for answer_node_id in self.route_position:
+            # stream out only when every node this answer depends on has finished (left rest_node_ids)
+            if not all(dep_id not in self.rest_node_ids
+                       for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
+                continue
+
+            route_position = self.route_position[answer_node_id]
+            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
+
+            for route_chunk in route_chunks:
+                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
+                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                    yield NodeRunStreamChunkEvent(
+                        chunk_content=route_chunk.text,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                    )
+                else:
+                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                    value_selector = route_chunk.value_selector
+                    if not value_selector:
+                        break
+
+                    value = self.variable_pool.get_variable_value(
+                        variable_selector=value_selector
+                    )
+
+                    if value is None:
+                        break
+
+                    text = ''
+                    if isinstance(value, str | int | float):
+                        text = str(value)
+                    elif isinstance(value, FileVar):
+                        # convert file to markdown
+                        text = value.to_markdown()
+                    elif isinstance(value, dict):
+                        # handle files
+                        file_vars = self._fetch_files_from_variable_value(value)
+                        if file_vars:
+                            file_var = file_vars[0]
+                            try:
+                                file_var_obj = FileVar(**file_var)
+
+                                # convert file to markdown
+                                text = file_var_obj.to_markdown()
+                            except Exception as e:
+                                logger.error(f'Error creating file var: {e}')
+
+                        if not text:
+                            # other types
+                            text = json.dumps(value, ensure_ascii=False)
+                    elif isinstance(value, list):
+                        # handle files
+                        file_vars = self._fetch_files_from_variable_value(value)
+                        for file_var in file_vars:
+                            try:
+                                file_var_obj = FileVar(**file_var)
+                            except Exception as e:
+                                logger.error(f'Error creating file var: {e}')
+                                continue
+
+                            # convert file to markdown
+                            text = file_var_obj.to_markdown() + ' '
+
+                        text = text.strip()
+
+                        if not text and value:
+                            # other types
+                            text = json.dumps(value, ensure_ascii=False)
+
+                    if text:
+                        yield NodeRunStreamChunkEvent(
+                            chunk_content=text,
+                            from_variable_selector=value_selector,
+                            route_node_state=event.route_node_state,
+                            parallel_id=event.parallel_id,
+                        )
+
+                self.route_position[answer_node_id] += 1
+
+    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Get the answer node ids that should stream out the given chunk.
+        :param event: node run stream chunk event
+        :return:
+        """
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_answer_node_ids = []
+        for answer_node_id in self.route_position:
+            if answer_node_id not in self.rest_node_ids:
+                continue
+
+            # stream out only when every node this answer depends on has finished (left rest_node_ids)
+            if all(dep_id not in self.rest_node_ids
+                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
+                route_position = self.route_position[answer_node_id]
+                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
+                    continue
+
+                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
+
+                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
+                    continue
+
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+
+                # only stream out when the chunk's variable selector matches this route chunk's selector
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_answer_node_ids.append(answer_node_id)
+
+        return stream_out_answer_node_ids
+
+    @classmethod
+    def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
+        """
+        Fetch files from variable value
+        :param value: variable value
+        :return:
+        """
+        if not value:
+            return []
+
+        files = []
+        if isinstance(value, list):
+            for item in value:
+                file_var = cls._get_file_var_from_value(item)
+                if file_var:
+                    files.append(file_var)
+        elif isinstance(value, dict):
+            file_var = cls._get_file_var_from_value(value)
+            if file_var:
+                files.append(file_var)
+
+        return files
+
+    @classmethod
+    def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
+        """
+        Get file var from value
+        :param value: variable value
+        :return:
+        """
+        if not value:
+            return None
+
+        if isinstance(value, dict):
+            if '__variant' in value and value['__variant'] == FileVar.__name__:
+                return value
+        elif isinstance(value, FileVar):
+            return value.to_dict()
+
+        return None

diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py
index 6f29af2027..620c2c426b 100644
--- a/api/core/workflow/nodes/answer/entities.py
+++ b/api/core/workflow/nodes/answer/entities.py
@@ -42,11 +42,21 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
     text: str = Field(..., description="text")


+class AnswerNodeDoubleLink(BaseModel):
+    node_id: str = Field(..., description="node id")
+    source_node_ids: list[str] = Field(..., description="source node ids")
+    target_node_ids: list[str] = Field(..., description="target node ids")
+
+
 class AnswerStreamGenerateRoute(BaseModel):
     """
-    ChatflowStreamGenerateRoute entity
+    AnswerStreamGenerateRoute entity
     """
-    answer_node_id: str = Field(..., description="answer node ID")
-    generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
-    current_route_position: int = 0
-    """current generate route position"""
+    answer_dependencies: dict[str, list[str]] = Field(
+        ...,
+        description="answer dependencies (answer node id -> dependent answer node ids)"
+    )
+    answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
+        ...,
+        description="answer generate route (answer node id -> generate route chunks)"
+    )

diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py b/api/tests/unit_tests/core/workflow/nodes/answer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
similarity index 54%
rename from api/tests/unit_tests/core/workflow/nodes/test_answer.py
rename to api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index c4095f8c3a..376d3f6521 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -1,29 +1,55 @@
+import time
 from unittest.mock import MagicMock

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.node_entities import SystemVariable, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.answer.answer_node import AnswerNode
 from extensions.ext_database import db
-from models.workflow import WorkflowNodeExecutionStatus
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


 def test_execute_answer():
-    node = AnswerNode(
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm-target",
+                "source": "start",
+                "target": "llm",
+            },
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start"
+                },
+                "id": "start"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm"
+            },
+        ]
+    }
+
+    graph = Graph.init(
+        graph_config=graph_config
+    )
+
+    init_params = GraphInitParams(
         tenant_id='1',
         app_id='1',
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
-        config={
-            'id': 'answer',
-            'data': {
-                'title': '123',
-                'type': 'answer',
-                'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
-            }
-        }
+        call_depth=0
     )

     # construct variable pool
@@ -34,11 +60,28 @@ def test_execute_answer():
     pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
     pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')

+    node = AnswerNode(
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+            start_at=time.perf_counter()
+        ),
+        config={
+            'id': 'answer',
+            'data': {
+                'title': '123',
+                'type': 'answer',
+                'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
+            }
+        }
+    )
+
     # Mock db.session.close()
     db.session.close = MagicMock()

     # execute node
-    result = node._run(pool)
+    result = node._run()

     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py
new file mode 100644
index 0000000000..7a48c3548b
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_generate_router.py
@@ -0,0 +1,125 @@
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
+
+
+def test_init():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm3-source-llm4-target",
+                "source": "llm3",
+                "target": "llm4",
+            },
+            {
+                "id": "llm3-source-llm5-target",
+                "source": "llm3",
+                "target": "llm5",
+            },
+            {
+                "id": "llm4-source-answer2-target",
+                "source": "llm4",
+                "target": "answer2",
+            },
+            {
+                "id": "llm5-source-answer-target",
+                "source": "llm5",
+                "target": "answer",
+            },
+            {
+                "id": "answer2-source-answer-target",
+                "source": "answer2",
+                "target": "answer",
+            },
+            {
+                "id": "llm2-source-answer-target",
+                "source": "llm2",
+                "target": "answer",
+            },
+            {
+                "id": "llm1-source-answer-target",
+                "source": "llm1",
+                "target": "answer",
+            }
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start"
+                },
+                "id": "start"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm1"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm2"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm3"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm4"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm5"
+            },
+            {
+                "data": {
+                    "type": "answer",
+                    "title": "answer",
+                    "answer": "1{{#llm2.text#}}2"
+                },
+                "id": "answer",
+            },
+            {
+                "data": {
+                    "type": "answer",
+                    "title": "answer2",
+                    "answer": "1{{#llm3.text#}}2"
+                },
+                "id": "answer2",
+            },
+        ],
+    }
+
+    graph = Graph.init(
+        graph_config=graph_config
+    )
+
+    answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
+        node_id_config_mapping=graph.node_id_config_mapping,
+        reverse_edge_mapping=graph.reverse_edge_mapping
+    )
+
+    assert answer_stream_generate_route.answer_dependencies['answer'] == ['answer2']
+    assert answer_stream_generate_route.answer_dependencies['answer2'] == []

diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
new file mode 100644
index 0000000000..c0c95d7bb8
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -0,0 +1,206 @@
+from collections.abc import Generator
+from datetime import datetime, timezone
+
+from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
+from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+
+
+def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
+    if next_node_id == 'start':
+        yield from _publish_events(graph, next_node_id)
+
+    for edge in graph.edge_mapping.get(next_node_id, []):
+        yield from _publish_events(graph, edge.target_node_id)
+
+    for edge in graph.edge_mapping.get(next_node_id, []):
+        yield from _recursive_process(graph, edge.target_node_id)
+
+
+def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
+    route_node_state = RouteNodeState(
+        node_id=next_node_id,
+        start_at=datetime.now(timezone.utc).replace(tzinfo=None)
+    )
+
+    yield NodeRunStartedEvent(
+        route_node_state=route_node_state,
+        parallel_id=graph.node_parallel_mapping.get(next_node_id),
+    )
+
+    if 'llm' in next_node_id:
+        length = int(next_node_id[-1])
+        for i in range(0, length):
+            yield NodeRunStreamChunkEvent(
+                chunk_content=str(i),
+                route_node_state=route_node_state,
+                from_variable_selector=[next_node_id, "text"],
+                parallel_id=graph.node_parallel_mapping.get(next_node_id),
+            )
+
+    route_node_state.status = RouteNodeState.Status.SUCCESS
+    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    yield NodeRunSucceededEvent(
+        route_node_state=route_node_state,
+        parallel_id=graph.node_parallel_mapping.get(next_node_id),
+    )
+
+
+def test_process():
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm1-target",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "start-source-llm2-target",
+                "source": "start",
+                "target": "llm2",
+            },
+            {
+                "id": "start-source-llm3-target",
+                "source": "start",
+                "target": "llm3",
+            },
+            {
+                "id": "llm3-source-llm4-target",
+                "source": "llm3",
+                "target": "llm4",
+            },
+            {
+                "id": "llm3-source-llm5-target",
+                "source": "llm3",
+                "target": "llm5",
+            },
+            {
+                "id": "llm4-source-answer2-target",
+                "source": "llm4",
+                "target": "answer2",
+            },
+            {
+                "id": "llm5-source-answer-target",
+                "source": "llm5",
+                "target": "answer",
+            },
"answer2-source-answer-target", + "source": "answer2", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "llm", + }, + "id": "llm4" + }, + { + "data": { + "type": "llm", + }, + "id": "llm5" + }, + { + "data": { + "type": "answer", + "title": "answer", + "answer": "a{{#llm2.text#}}b" + }, + "id": "answer", + }, + { + "data": { + "type": "answer", + "title": "answer2", + "answer": "c{{#llm3.text#}}d" + }, + "id": "answer2", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + variable_pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) + + answer_stream_processor = AnswerStreamProcessor( + graph=graph, + variable_pool=variable_pool + ) + + def graph_generator() -> Generator[GraphEngineEvent, None, None]: + # print("") + for event in _recursive_process(graph, "start"): + # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunSucceededEvent): + if 'llm' in event.route_node_state.node_id: + variable_pool.append_variable( + event.route_node_state.node_id, + ["text"], + "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))) + ) + yield event + + result_generator = answer_stream_processor.process(graph_generator()) + stream_contents = "" + for event in result_generator: + # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id, + # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) + if isinstance(event, NodeRunStreamChunkEvent): + stream_contents += event.chunk_content + pass + + assert stream_contents == "c012da01b" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 126040394c..aeed00f359 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -1,21 +1,75 @@ +import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType def test_execute_if_else_result_true(): - node = IfElseNode( + graph_config = { + "edges": [], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + } + ] + } + + graph = Graph.init( + graph_config=graph_config + ) + + init_params = 
+    init_params = GraphInitParams(
         tenant_id='1',
         app_id='1',
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={})
+    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
+    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
+    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
+    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
+    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
+    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
+    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
+    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
+    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
+    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
+    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
+    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
+    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
+    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
+    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
+    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
+    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
+
+    node = IfElseNode(
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+            start_at=time.perf_counter()
+        ),
         config={
             'id': 'if-else',
             'data': {
@@ -116,34 +170,11 @@ def test_execute_if_else_result_true():
         }
     )

-    # construct variable pool
-    pool = VariablePool(system_variables={
-        SystemVariable.FILES: [],
-        SystemVariable.USER_ID: 'aaa'
-    }, user_inputs={})
-    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
-    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
-    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
-    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
-    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
-    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
-    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
-    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
-    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
-    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
-    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
-    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
-    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
-    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
-    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
-    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
-    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
-
     # Mock db.session.close()
     db.session.close = MagicMock()

     # execute node
-    result = node._run(pool)
+    result = node._run()

     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['result'] is True
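
Note (not part of the patch): a minimal usage sketch of how the pieces introduced above fit together. It mirrors the unit tests in this commit; the two-step graph config is illustrative, `events` stands in for whatever generator GraphEngine._run() returns, and the empty system_variables dict is an assumption for brevity:

    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.graph_engine.entities.event import NodeRunStreamChunkEvent
    from core.workflow.graph_engine.entities.graph import Graph
    from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor

    # A simple chat graph: start -> llm -> answer (illustrative config only).
    graph = Graph.init(graph_config={
        "edges": [
            {"id": "start-source-llm-target", "source": "start", "target": "llm"},
            {"id": "llm-source-answer-target", "source": "llm", "target": "answer"},
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "llm"}, "id": "llm"},
            {"data": {"type": "answer", "title": "answer", "answer": "a{{#llm.text#}}b"}, "id": "answer"},
        ],
    })

    # Graph.init() already built graph.answer_stream_generate_routes via
    # AnswerStreamGeneratorRouter.init(), so the processor only needs the graph
    # plus a variable pool to resolve {{#node.var#}} selectors at stream time.
    variable_pool = VariablePool(system_variables={}, user_inputs={})
    processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)

    # `events` is assumed to be the generator returned by GraphEngine._run();
    # the processor forwards engine events unchanged and interleaves the answer
    # template text ("a", "b") around the LLM chunks as dependencies finish.
    for event in processor.process(events):
        if isinstance(event, NodeRunStreamChunkEvent):
            print(event.chunk_content, end="")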