diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 714503df86..85a1799719 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -7,6 +7,8 @@ from core.workflow.entities.node_entities import NodeType
 from core.workflow.graph_engine.entities.run_condition import RunCondition
 from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
 from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
+from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
+from core.workflow.nodes.end.entities import EndStreamParam
 
 
 class GraphEdge(BaseModel):
@@ -52,6 +54,10 @@ class Graph(BaseModel):
         ...,
         description="answer stream generate routes"
     )
+    end_stream_param: EndStreamParam = Field(
+        ...,
+        description="end stream param"
+    )
 
     @classmethod
     def init(cls,
@@ -166,6 +172,12 @@ class Graph(BaseModel):
             reverse_edge_mapping=reverse_edge_mapping
         )
 
+        # init end stream param
+        end_stream_param = EndStreamGeneratorRouter.init(
+            node_id_config_mapping=node_id_config_mapping,
+            reverse_edge_mapping=reverse_edge_mapping
+        )
+
         # init graph
         graph = cls(
             root_node_id=root_node_id,
@@ -175,7 +187,8 @@ class Graph(BaseModel):
             reverse_edge_mapping=reverse_edge_mapping,
             parallel_mapping=parallel_mapping,
             node_parallel_mapping=node_parallel_mapping,
-            answer_stream_generate_routes=answer_stream_generate_routes
+            answer_stream_generate_routes=answer_stream_generate_routes,
+            end_stream_param=end_stream_param
         )
 
         return graph
diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
index 6e7dcb7ede..6cb80091c9 100644
--- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py
+++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
@@ -167,37 +167,3 @@ class AnswerStreamGeneratorRouter:
                 reverse_edge_mapping=reverse_edge_mapping,
                 answer_dependencies=answer_dependencies
             )
-
-    @classmethod
-    def _fetch_answer_dependencies(cls,
-                                   current_node_id: str,
-                                   answer_node_id: str,
-                                   answer_node_ids: list[str],
-                                   reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
-                                   answer_dependencies: dict[str, list[str]]
-                                   ) -> None:
-        """
-        Fetch answer dependencies
-        :param current_node_id: current node id
-        :param answer_node_id: answer node id
-        :param answer_node_ids: answer node ids
-        :param reverse_edge_mapping: reverse edge mapping
-        :param answer_dependencies: answer dependencies
-        :return:
-        """
-        for edge in reverse_edge_mapping.get(current_node_id, []):
-            source_node_id = edge.source_node_id
-            if source_node_id == answer_node_id:
-                continue
-
-            if source_node_id in answer_node_ids:
-                # is answer node
-                answer_dependencies[answer_node_id].append(source_node_id)
-            else:
-                cls._fetch_answer_dependencies(
-                    current_node_id=source_node_id,
-                    answer_node_id=answer_node_id,
-                    answer_node_ids=answer_node_ids,
-                    reverse_edge_mapping=reverse_edge_mapping,
-                    answer_dependencies=answer_dependencies
-                )
diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py
new file mode 100644
index 0000000000..ac386ee796
--- /dev/null
+++ b/api/core/workflow/nodes/end/end_stream_generate_router.py
@@ -0,0 +1,142 @@
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
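+
+# EndStreamGeneratorRouter pre-computes streaming metadata for End nodes: for each
+# End node it collects the output value selectors that point at an upstream LLM
+# "text" output (and can therefore be streamed chunk by chunk), plus the branch
+# nodes (if/else, question classifier) the End node transitively depends on.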
+
+
+class EndStreamGeneratorRouter:
+
+    @classmethod
+    def init(cls,
+             node_id_config_mapping: dict[str, dict],
+             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
+             ) -> EndStreamParam:
+        """
+        Get end stream param.
+        :return:
+        """
+        # parse the stream output value selectors of all end nodes
+        end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
+        for end_node_id, node_config in node_id_config_mapping.items():
+            if node_config.get('data', {}).get('type') != NodeType.END.value:
+                continue
+
+            # get the value selectors that can be streamed for this end node
+            stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
+            end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
+
+        # fetch end dependencies
+        end_node_ids = list(end_stream_variable_selectors_mapping.keys())
+        end_dependencies = cls._fetch_ends_dependencies(
+            end_node_ids=end_node_ids,
+            reverse_edge_mapping=reverse_edge_mapping,
+            node_id_config_mapping=node_id_config_mapping
+        )
+
+        return EndStreamParam(
+            end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
+            end_dependencies=end_dependencies
+        )
+
+    @classmethod
+    def extract_stream_variable_selector_from_node_data(cls,
+                                                        node_id_config_mapping: dict[str, dict],
+                                                        node_data: EndNodeData) -> list[list[str]]:
+        """
+        Extract stream variable selector from node data
+        :param node_id_config_mapping: node id config mapping
+        :param node_data: node data object
+        :return:
+        """
+        variable_selectors = node_data.outputs
+
+        value_selectors = []
+        for variable_selector in variable_selectors:
+            if not variable_selector.value_selector:
+                continue
+
+            node_id = variable_selector.value_selector[0]
+            if node_id != 'sys' and node_id in node_id_config_mapping:
+                node = node_id_config_mapping[node_id]
+                node_type = node.get('data', {}).get('type')
+                if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
+                    value_selectors.append(variable_selector.value_selector)
+
+        # remove duplicates (selectors are lists, which are unhashable, so dedupe via tuples)
+        value_selectors = [list(selector)
+                           for selector in dict.fromkeys(tuple(selector) for selector in value_selectors)]
+
+        return value_selectors
+
+    @classmethod
+    def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
+            -> list[list[str]]:
+        """
+        Extract stream variable selector from node config
+        :param node_id_config_mapping: node id config mapping
+        :param config: node config
+        :return:
+        """
+        node_data = EndNodeData(**config.get("data", {}))
+        return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
+
+    @classmethod
+    def _fetch_ends_dependencies(cls,
+                                 end_node_ids: list[str],
+                                 reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                 node_id_config_mapping: dict[str, dict]
+                                 ) -> dict[str, list[str]]:
+        """
+        Fetch end dependencies
+        :param end_node_ids: end node ids
+        :param reverse_edge_mapping: reverse edge mapping
+        :param node_id_config_mapping: node id config mapping
+        :return:
+        """
+        end_dependencies: dict[str, list[str]] = {}
+        for end_node_id in end_node_ids:
+            if end_dependencies.get(end_node_id) is None:
+                end_dependencies[end_node_id] = []
+
+            cls._recursive_fetch_end_dependencies(
+                current_node_id=end_node_id,
+                end_node_id=end_node_id,
+                node_id_config_mapping=node_id_config_mapping,
+                reverse_edge_mapping=reverse_edge_mapping,
+                end_dependencies=end_dependencies
+            )
+
+        return end_dependencies
+
+    @classmethod
+    def _recursive_fetch_end_dependencies(cls,
+                                          current_node_id: str,
+                                          end_node_id: str,
+                                          node_id_config_mapping: dict[str, dict],
+                                          reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
+                                          end_dependencies: dict[str, list[str]]
+                                          ) -> None:
+        """
+        Recursively fetch end dependencies
+        :param current_node_id: current node id
+        :param end_node_id: end node id
+        :param node_id_config_mapping: node id config mapping
+        :param reverse_edge_mapping: reverse edge mapping
+        :param end_dependencies: end dependencies
+        :return:
+        """
+        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
+        for edge in reverse_edges:
+            source_node_id = edge.source_node_id
+            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
+            if source_node_type in (
+                    NodeType.IF_ELSE.value,
+                    NodeType.QUESTION_CLASSIFIER.value,
+            ):
+                # branch node: the end node's stream output depends on this branch being resolved
+                end_dependencies[end_node_id].append(source_node_id)
+            else:
+                cls._recursive_fetch_end_dependencies(
+                    current_node_id=source_node_id,
+                    end_node_id=end_node_id,
+                    node_id_config_mapping=node_id_config_mapping,
+                    reverse_edge_mapping=reverse_edge_mapping,
+                    end_dependencies=end_dependencies
+                )
diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py
new file mode 100644
index 0000000000..960d04988e
--- /dev/null
+++ b/api/core/workflow/nodes/end/end_stream_processor.py
@@ -0,0 +1,207 @@
+import logging
+from collections.abc import Generator
+from typing import cast
+
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.event import (
+    GraphEngineEvent,
+    NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
+
+logger = logging.getLogger(__name__)
+
+
+class EndStreamProcessor:
+
+    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
+        self.graph = graph
+        self.variable_pool = variable_pool
+        self.stream_param = graph.end_stream_param
+        self.end_streamed_variable_selectors: dict[str, list[str]] = {
+            end_node_id: [] for end_node_id in graph.end_stream_param.end_stream_variable_selector_mapping
+        }
+
+        self.rest_node_ids = graph.node_ids.copy()
+        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
+
+    def process(self,
+                generator: Generator[GraphEngineEvent, None, None]
+                ) -> Generator[GraphEngineEvent, None, None]:
+        for event in generator:
+            if isinstance(event, NodeRunStreamChunkEvent):
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ]
+                else:
+                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
+                    self.current_stream_chunk_generating_node_ids[
+                        event.route_node_state.node_id
+                    ] = stream_out_answer_node_ids
+
+                for _ in stream_out_answer_node_ids:
+                    yield event
+            elif isinstance(event, NodeRunSucceededEvent):
+                yield event
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                    # update self.route_position after all stream events have finished
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                        self.route_position[answer_node_id] += 1
+
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+
+                # remove unreachable nodes
+                self._remove_unreachable_nodes(event)
+
+                # generate stream outputs
+                yield from self._generate_stream_outputs_when_node_finished(event)
+            else:
+                yield event
+
+    def reset(self) -> None:
+        self.route_position = {}
+        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
+            self.route_position[answer_node_id] = 0
+        self.rest_node_ids = self.graph.node_ids.copy()
+
+    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+        finished_node_id = event.route_node_state.node_id
+
+        if finished_node_id not in self.rest_node_ids:
+            return
+
+        # remove finished node id
+        self.rest_node_ids.remove(finished_node_id)
+
+        run_result = event.route_node_state.node_run_result
+        if not run_result:
+            return
+
+        if run_result.edge_source_handle:
+            reachable_node_ids = []
+            unreachable_first_node_ids = []
+            for edge in self.graph.edge_mapping[finished_node_id]:
+                if (edge.run_condition
+                        and edge.run_condition.branch_identify
+                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
+                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+                    continue
+                else:
+                    unreachable_first_node_ids.append(edge.target_node_id)
+
+            for node_id in unreachable_first_node_ids:
+                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
+
+    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
+        node_ids = []
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            node_ids.append(edge.target_node_id)
+            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+        return node_ids
+
+    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
+        """
+        remove target node ids until merge
+        """
+        self.rest_node_ids.remove(node_id)
+        for edge in self.graph.edge_mapping.get(node_id, []):
+            if edge.target_node_id in reachable_node_ids:
+                continue
+
+            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
+
+    def _generate_stream_outputs_when_node_finished(self,
+                                                    event: NodeRunSucceededEvent
+                                                    ) -> Generator[GraphEngineEvent, None, None]:
+        """
+        Generate stream outputs.
+        :param event: node run succeeded event
+        :return:
+        """
+        for answer_node_id, position in self.route_position.items():
+            # skip unless this is the answer node itself, or the answer node is still
+            # pending and all of its dependencies have already finished
+            if (event.route_node_state.node_id != answer_node_id
+                    and (answer_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
+                continue
+
+            route_position = self.route_position[answer_node_id]
+            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
+
+            for route_chunk in route_chunks:
+                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
+                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                    yield NodeRunStreamChunkEvent(
+                        chunk_content=route_chunk.text,
+                        route_node_state=event.route_node_state,
+                        parallel_id=event.parallel_id,
+                    )
+                else:
+                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                    value_selector = route_chunk.value_selector
+                    if not value_selector:
+                        break
+
+                    value = self.variable_pool.get(
+                        value_selector
+                    )
+
+                    if value is None:
+                        break
+
+                    text = value.markdown
+
+                    if text:
+                        yield NodeRunStreamChunkEvent(
+                            chunk_content=text,
+                            from_variable_selector=value_selector,
+                            route_node_state=event.route_node_state,
+                            parallel_id=event.parallel_id,
+                        )
+
+                self.route_position[answer_node_id] += 1
+
+    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
+        """
+        Get the answer node ids that this chunk can be streamed out for
+        :param event: node run stream chunk event
+        :return:
+        """
+        if not event.from_variable_selector:
+            return []
+
+        stream_output_value_selector = event.from_variable_selector
+        if not stream_output_value_selector:
+            return []
+
+        stream_out_answer_node_ids = []
+        for answer_node_id, position in self.route_position.items():
+            if answer_node_id not in self.rest_node_ids:
+                continue
+
+            # stream out only if all nodes the answer node depends on have already finished
+            if all(dep_id not in self.rest_node_ids
+                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
+                route_position = self.route_position[answer_node_id]
+                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
+                    continue
+
+                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
+
+                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
+                    continue
+
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+
+                # only stream when the chunk's variable selector matches this route chunk
+                if value_selector != stream_output_value_selector:
+                    continue
+
+                stream_out_answer_node_ids.append(answer_node_id)
+
+        return stream_out_answer_node_ids
diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py
index ad4fc8f04f..a0edf7b579 100644
--- a/api/core/workflow/nodes/end/entities.py
+++ b/api/core/workflow/nodes/end/entities.py
@@ -1,3 +1,5 @@
+from pydantic import BaseModel, Field
+
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
 
@@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
     END Node Data.
""" outputs: list[VariableSelector] + + +class EndStreamParam(BaseModel): + """ + EndStreamParam entity + """ + end_dependencies: dict[str, list[str]] = Field( + ..., + description="end dependencies (end node id -> dependent node ids)" + ) + end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( + ..., + description="end stream variable selector mapping (end node id -> stream variable selectors)" + )