From f4eb7cd0370e048ea96023d24d6934fe2ec32297 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 25 Jul 2024 04:03:53 +0800
Subject: [PATCH] add end stream output test

---
 .../workflow/graph_engine/graph_engine.py     |  26 +-
 .../nodes/answer/answer_stream_processor.py   |   4 +-
 api/core/workflow/nodes/end/end_node.py       |  43 ---
 .../nodes/end/end_stream_generate_router.py   |   4 +-
 .../nodes/end/end_stream_processor.py         | 112 ++------
 .../graph_engine/test_graph_engine.py         | 261 +++++++++++++++++-
 6 files changed, 296 insertions(+), 154 deletions(-)

diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 7fcdc46223..dcbedbeff8 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -29,6 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
+from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
 
 # from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
 from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
@@ -82,14 +83,21 @@ class GraphEngine:
         yield GraphRunStartedEvent()
 
         try:
-            # run graph
-            generator = self._run(start_node_id=self.graph.root_node_id)
             if self.init_params.workflow_type == WorkflowType.CHAT:
-                answer_stream_processor = AnswerStreamProcessor(
+                stream_processor = AnswerStreamProcessor(
                     graph=self.graph,
                     variable_pool=self.graph_runtime_state.variable_pool
                 )
-                generator = answer_stream_processor.process(generator)
+            else:
+                stream_processor = EndStreamProcessor(
+                    graph=self.graph,
+                    variable_pool=self.graph_runtime_state.variable_pool
+                )
+
+            # run graph
+            generator = stream_processor.process(
+                self._run(start_node_id=self.graph.root_node_id)
+            )
 
             for item in generator:
                 yield item
@@ -151,6 +159,11 @@ class GraphEngine:
                 )
                 raise e
 
+            # stop traversal once an END node has run; nothing is routed after it
+            if (self.graph.node_id_config_mapping[next_node_id]
+                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
+                break
+
             previous_route_node_state = route_node_state
 
             # get next node ids
@@ -160,11 +173,6 @@ class GraphEngine:
 
             if len(edge_mappings) == 1:
                 next_node_id = edge_mappings[0].target_node_id
-
-                # It may not be necessary, but it is necessary. :)
-                if (self.graph.node_id_config_mapping[next_node_id]
-                        .get("data", {}).get("type", "").lower() == NodeType.END.value):
-                    break
             else:
                 if any(edge.run_condition for edge in edge_mappings):
                     # if nodes has run conditions, get node id which branch to take based on the run condition results
diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
index 851a66c9ba..9bec7e9fc2 100644
--- a/api/core/workflow/nodes/answer/answer_stream_processor.py
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -66,6 +66,7 @@ class AnswerStreamProcessor:
         for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
             self.route_position[answer_node_id] = 0
         self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
 
     def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
         finished_node_id = event.route_node_state.node_id
@@ -179,14 +180,13 @@ class AnswerStreamProcessor:
             return []
 
         stream_out_answer_node_ids = []
-        for answer_node_id, position in self.route_position.items():
+        for answer_node_id, route_position in self.route_position.items():
             if answer_node_id not in self.rest_node_ids:
                 continue
 
             # all depends on answer node id not in rest node ids
             if all(dep_id not in self.rest_node_ids
                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
-                route_position = self.route_position[answer_node_id]
                 if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
                     continue
 
diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py
index 9ab6ef610e..fc24873d16 100644
--- a/api/core/workflow/nodes/end/end_node.py
+++ b/api/core/workflow/nodes/end/end_node.py
@@ -31,49 +31,6 @@ class EndNode(BaseNode):
             outputs=outputs
         )
 
-    @classmethod
-    def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
-        """
-        Extract generate nodes
-        :param graph: graph
-        :param config: node config
-        :return:
-        """
-        node_data = cls._node_data_cls(**config.get("data", {}))
-        node_data = cast(EndNodeData, node_data)
-
-        return cls.extract_generate_nodes_from_node_data(graph, node_data)
-
-    @classmethod
-    def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
-        """
-        Extract generate nodes from node data
-        :param graph: graph
-        :param node_data: node data object
-        :return:
-        """
-        nodes = graph.get('nodes', [])
-        node_mapping = {node.get('id'): node for node in nodes}
-
-        variable_selectors = node_data.outputs
-
-        generate_nodes = []
-        for variable_selector in variable_selectors:
-            if not variable_selector.value_selector:
-                continue
-
-            node_id = variable_selector.value_selector[0]
-            if node_id != 'sys' and node_id in node_mapping:
-                node = node_mapping[node_id]
-                node_type = node.get('data', {}).get('type')
-                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
-                    generate_nodes.append(node_id)
-
-        # remove duplicates
-        generate_nodes = list(set(generate_nodes))
-
-        return generate_nodes
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """
diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py
index ac386ee796..d2c9578019 100644
--- a/api/core/workflow/nodes/end/end_stream_generate_router.py
+++ b/api/core/workflow/nodes/end/end_stream_generate_router.py
@@ -61,7 +61,9 @@ class EndStreamGeneratorRouter:
                 value_selectors.append(variable_selector.value_selector)
 
         # remove duplicates
-        value_selectors = list(set(value_selectors))
+        value_selector_tuples = [tuple(item) for item in value_selectors]
+        unique_value_selector_tuples = list(set(value_selector_tuples))
+        value_selectors = [list(item) for item in unique_value_selector_tuples]
 
         return value_selectors
 
diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py
index 960d04988e..60bb97c0b1 100644
--- a/api/core/workflow/nodes/end/end_stream_processor.py
+++ b/api/core/workflow/nodes/end/end_stream_processor.py
@@ -1,6 +1,5 @@
 import logging
 from collections.abc import Generator
-from typing import cast
 
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
@@ -9,7 +8,6 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
 
 logger = logging.getLogger(__name__)
 
@@ -20,10 +18,7 @@ class EndStreamProcessor:
         self.graph = graph
         self.variable_pool = variable_pool
         self.stream_param = graph.end_stream_param
-        self.end_streamed_variable_selectors: dict[str, list[str]] = {
-            end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
-        }
-
+        self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
        self.rest_node_ids = graph.node_ids.copy()
         self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
 
@@ -33,43 +28,35 @@ class EndStreamProcessor:
         for event in generator:
             if isinstance(event, NodeRunStreamChunkEvent):
                 if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
-                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
+                    stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
                         event.route_node_state.node_id
                     ]
                 else:
-                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
+                    stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
                     self.current_stream_chunk_generating_node_ids[
                         event.route_node_state.node_id
-                    ] = stream_out_answer_node_ids
+                    ] = stream_out_end_node_ids
 
-                for _ in stream_out_answer_node_ids:
+                for _ in stream_out_end_node_ids:
                     yield event
             elif isinstance(event, NodeRunSucceededEvent):
                 yield event
                 if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
-                    # update self.route_position after all stream event finished
-                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
-                        self.route_position[answer_node_id] += 1
-
                     del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
 
                 # remove unreachable nodes
                 self._remove_unreachable_nodes(event)
-
-                # generate stream outputs
-                yield from self._generate_stream_outputs_when_node_finished(event)
             else:
                 yield event
 
     def reset(self) -> None:
-        self.route_position = {}
-        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
-            self.route_position[answer_node_id] = 0
+        # mirror __init__: keep the streamable variable selectors per END node
+        self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
         self.rest_node_ids = self.graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids = {}
 
     def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
         finished_node_id = event.route_node_state.node_id
-
         if finished_node_id not in self.rest_node_ids:
             return
 
@@ -113,59 +100,7 @@ class EndStreamProcessor:
             self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
 
-    def _generate_stream_outputs_when_node_finished(self,
-                                                    event: NodeRunSucceededEvent
-                                                    ) -> Generator[GraphEngineEvent, None, None]:
-        """
-        Generate stream outputs.
-        :param event: node run succeeded event
-        :return:
-        """
-        for answer_node_id, position in self.route_position.items():
-            # all depends on answer node id not in rest node ids
-            if (event.route_node_state.node_id != answer_node_id
-                    and (answer_node_id not in self.rest_node_ids
-                         or not all(dep_id not in self.rest_node_ids
-                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
-                continue
-
-            route_position = self.route_position[answer_node_id]
-            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
-
-            for route_chunk in route_chunks:
-                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
-                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
-                    yield NodeRunStreamChunkEvent(
-                        chunk_content=route_chunk.text,
-                        route_node_state=event.route_node_state,
-                        parallel_id=event.parallel_id,
-                    )
-                else:
-                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-                    value_selector = route_chunk.value_selector
-                    if not value_selector:
-                        break
-
-                    value = self.variable_pool.get(
-                        value_selector
-                    )
-
-                    if value is None:
-                        break
-
-                    text = value.markdown
-
-                    if text:
-                        yield NodeRunStreamChunkEvent(
-                            chunk_content=text,
-                            from_variable_selector=value_selector,
-                            route_node_state=event.route_node_state,
-                            parallel_id=event.parallel_id,
-                        )
-
-            self.route_position[answer_node_id] += 1
-
-    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
         """
         Is stream out support
         :param event: queue text chunk event
         :return:
@@ -178,30 +113,17 @@ class EndStreamProcessor:
         if not stream_output_value_selector:
             return []
 
-        stream_out_answer_node_ids = []
-        for answer_node_id, position in self.route_position.items():
-            if answer_node_id not in self.rest_node_ids:
+        stream_out_end_node_ids = []
+        for end_node_id, variable_selectors in self.end_streamed_variable_selectors.items():
+            if end_node_id not in self.rest_node_ids:
                 continue
 
-            # all depends on answer node id not in rest node ids
+            # all depends on end node id not in rest node ids
             if all(dep_id not in self.rest_node_ids
-                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
-                route_position = self.route_position[answer_node_id]
-                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
+                   for dep_id in self.stream_param.end_dependencies[end_node_id]):
+                if stream_output_value_selector not in variable_selectors:
                     continue
 
-                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
+                stream_out_end_node_ids.append(end_node_id)
 
-                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
-                    continue
-
-                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
-                value_selector = route_chunk.value_selector
-
-                # check chunk node id is before current node id or equal to current node id
-                if value_selector != stream_output_value_selector:
-                    continue
-
-                stream_out_answer_node_ids.append(answer_node_id)
-
-        return stream_out_answer_node_ids
+        return stream_out_end_node_ids
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index 75e70ee65b..d69dcc1c28 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -1,7 +1,7 @@
 from unittest.mock import patch
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import SystemVariable, UserFrom
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, SystemVariable, UserFrom
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     BaseNodeEvent,
@@ -16,12 +16,265 @@ from core.workflow.graph_engine.entities.event import (
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 from core.workflow.graph_engine.graph_engine import GraphEngine
-from models.workflow import WorkflowType
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.llm.llm_node import LLMNode
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
 @patch('extensions.ext_database.db.session.remove')
 @patch('extensions.ext_database.db.session.close')
-def test_run_parallel(mock_close, mock_remove):
+def test_run_parallel_in_workflow(mock_close, mock_remove):
+    graph_config = {
+        "edges": [
+            {
+                "id": "1",
+                "source": "start",
+                "target": "llm1",
+            },
+            {
+                "id": "2",
+                "source": "llm1",
+                "target": "llm2",
+            },
+            {
+                "id": "3",
+                "source": "llm1",
+                "target": "llm3",
+            },
+            {
+                "id": "4",
+                "source": "llm2",
+                "target": "end1",
+            },
+            {
+                "id": "5",
+                "source": "llm3",
+                "target": "end2",
+            }
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start",
+                    "title": "start",
+                    "variables": [{
+                        "label": "query",
+                        "max_length": 48,
+                        "options": [],
+                        "required": True,
+                        "type": "text-input",
+                        "variable": "query"
+                    }]
+                },
+                "id": "start"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm1",
+                    "context": {
+                        "enabled": False,
+                        "variable_selector": []
+                    },
+                    "model": {
+                        "completion_params": {
+                            "temperature": 0.7
+                        },
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai"
+                    },
+                    "prompt_template": [{
+                        "role": "system",
+                        "text": "say hi"
+                    }, {
+                        "role": "user",
+                        "text": "{{#start.query#}}"
+                    }],
+                    "vision": {
+                        "configs": {
+                            "detail": "high"
+                        },
+                        "enabled": False
+                    }
+                },
+                "id": "llm1"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm2",
+                    "context": {
+                        "enabled": False,
+                        "variable_selector": []
+                    },
+                    "model": {
+                        "completion_params": {
+                            "temperature": 0.7
+                        },
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai"
+                    },
+                    "prompt_template": [{
+                        "role": "system",
+                        "text": "say bye"
+                    }, {
+                        "role": "user",
+                        "text": "{{#start.query#}}"
+                    }],
+                    "vision": {
+                        "configs": {
+                            "detail": "high"
+                        },
+                        "enabled": False
+                    }
+                },
+                "id": "llm2",
+            },
+            {
+                "data": {
+                    "type": "llm",
+                    "title": "llm3",
+                    "context": {
+                        "enabled": False,
+                        "variable_selector": []
+                    },
+                    "model": {
+                        "completion_params": {
+                            "temperature": 0.7
+                        },
+                        "mode": "chat",
+                        "name": "gpt-4o",
+                        "provider": "openai"
+                    },
+                    "prompt_template": [{
+                        "role": "system",
+                        "text": "say good morning"
+                    }, {
+                        "role": "user",
+                        "text": "{{#start.query#}}"
+                    }],
+                    "vision": {
+                        "configs": {
+                            "detail": "high"
+                        },
"enabled": False + } + }, + "id": "llm3", + }, + { + "data": { + "type": "end", + "title": "end1", + "outputs": [{ + "value_selector": ["llm2", "text"], + "variable": "result2" + }, { + "value_selector": ["start", "query"], + "variable": "query" + }], + }, + "id": "end1", + }, + { + "data": { + "type": "end", + "title": "end2", + "outputs": [{ + "value_selector": ["llm1", "text"], + "variable": "result1" + }, { + "value_selector": ["llm3", "text"], + "variable": "result3" + }], + }, + "id": "end2", + } + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + variable_pool = VariablePool(system_variables={ + SystemVariable.FILES: [], + SystemVariable.USER_ID: 'aaa' + }, user_inputs={ + "query": "hi" + }) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="333", + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200 + ) + + def llm_generator(self): + contents = [ + 'hi', + 'bye', + 'good morning' + ] + + yield RunStreamChunkEvent( + chunk_content=contents[int(self.node_id[-1]) - 1], + from_variable_selector=[self.node_id, 'text'] + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: 'USD' + } + ) + ) + + print("") + + with patch.object(LLMNode, '_run', new=llm_generator): + items = [] + generator = graph_engine.run() + for item in generator: + print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + 'llm2', 'llm3', 'end1', 'end2' + ]: + assert item.parallel_id is not None + + assert len(items) == 17 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == 'start' + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == 'start' + + +@patch('extensions.ext_database.db.session.remove') +@patch('extensions.ext_database.db.session.close') +def test_run_parallel_in_chatflow(mock_close, mock_remove): graph_config = { "edges": [ { @@ -291,7 +546,7 @@ def test_run_branch(mock_close, mock_remove): items = [] generator = graph_engine.run() for item in generator: - print(type(item), item) + # print(type(item), item) items.append(item) assert len(items) == 10