finished answer stream output

This commit is contained in:
takatost 2024-07-20 00:49:46 +08:00
parent 7ad77e9e77
commit dad1a967ee
15 changed files with 989 additions and 522 deletions

View File

@ -2,7 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
@ -33,7 +33,6 @@ from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
@ -43,20 +42,16 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@ -430,102 +425,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
@ -546,205 +445,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.

View File

@ -50,7 +50,8 @@ class NodeRunStartedEvent(BaseNodeEvent):
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):

View File

@ -5,21 +5,24 @@ from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = Field(None, description="run condition")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
end_to_node_id: Optional[str] = Field(None, description="end to node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
@ -33,6 +36,10 @@ class Graph(BaseModel):
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
@ -41,8 +48,8 @@ class Graph(BaseModel):
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
default_factory=dict,
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
...,
description="answer stream generate routes"
)
@ -66,6 +73,7 @@ class Graph(BaseModel):
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
@ -79,6 +87,9 @@ class Graph(BaseModel):
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
target_edge_ids.add(target_node_id)
# parse run condition
@ -91,11 +102,12 @@ class Graph(BaseModel):
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=edge_config.get('target'),
target_node_id=target_node_id,
run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get('nodes')
@ -149,9 +161,9 @@ class Graph(BaseModel):
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
reverse_edge_mapping=reverse_edge_mapping
)
# init graph
@ -160,6 +172,7 @@ class Graph(BaseModel):
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes

View File

@ -8,7 +8,11 @@ class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
start_at: float = Field(..., description="start time")
total_tokens: int = Field(0, description="total tokens")
node_run_steps: int = Field(0, description="node run steps")
total_tokens: int = 0
"""total tokens"""
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""

View File

@ -28,6 +28,9 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
@ -81,6 +84,10 @@ class GraphEngine:
try:
# run graph
generator = self._run(start_node_id=self.graph.root_node_id)
if self.init_params.workflow_type == WorkflowType.CHAT:
answer_stream_processor = AnswerStreamProcessor(self.graph)
generator = answer_stream_processor.process(generator)
for item in generator:
yield item
if isinstance(item, NodeRunFailedEvent):
@ -314,8 +321,6 @@ class GraphEngine:
db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
self.graph_runtime_state.node_run_steps += 1
try:
@ -335,7 +340,7 @@ class GraphEngine:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
# append node output variables to variable pool
@ -397,7 +402,7 @@ class GraphEngine:
self.graph_runtime_state.variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value
value=variable_value # type: ignore[arg-type]
)
# if variable_value is a dict, then recursively append variables

View File

@ -4,7 +4,7 @@ from typing import cast
from core.file.file_obj import FileVar
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
@ -29,7 +29,7 @@ class AnswerNode(BaseNode):
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
answer = ''
for part in generate_routes:

View File

@ -0,0 +1,203 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route,
answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER,
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
@classmethod
def _fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
"""
Fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
for edge in reverse_edge_mapping.get(current_node_id, []):
source_node_id = edge.source_node_id
if source_node_id == answer_node_id:
continue
if source_node_id in answer_node_ids:
# is answer node
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)

View File

@ -1,160 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamOutputManager:
@classmethod
def init_stream_generate_routes(cls,
node_id_config_mapping: dict[str, dict],
edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> dict[str, AnswerStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
streaming_node_ids = cls._get_streaming_node_ids(
target_node_id=node_id,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
)
if not streaming_node_ids:
continue
for streaming_node_id in streaming_node_ids:
stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
answer_node_id=node_id,
generate_route=generate_route
)
return stream_generate_routes
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _get_streaming_node_ids(cls,
target_node_id: str,
node_id_config_mapping: dict[str, dict],
edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]: # type: ignore[name-defined]
"""
Get answer stream node IDs.
:param target_node_id: target node ID
:return:
"""
# fetch all ingoing edges from source node
ingoing_graph_edges = []
for graph_edges in edge_mapping.values():
for graph_edge in graph_edges:
if graph_edge.target_node_id == target_node_id:
ingoing_graph_edges.append(graph_edge)
if not ingoing_graph_edges:
return []
streaming_node_ids = []
for ingoing_graph_edge in ingoing_graph_edges:
source_node_id = ingoing_graph_edge.source_node_id
source_node = node_id_config_mapping.get(source_node_id)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
# Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
streaming_node_ids.append(target_node_id)
elif node_type == NodeType.START.value:
# Current node is START node, can be stream node.
streaming_node_ids.append(source_node_id)
else:
# Find the stream node forward.
sub_streaming_node_ids = cls._get_streaming_node_ids(
target_node_id=source_node_id,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
)
if sub_streaming_node_ids:
streaming_node_ids.extend(sub_streaming_node_ids)
return streaming_node_ids
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys

View File

@ -0,0 +1,286 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, cast
from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor:
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStreamChunkEvent):
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping[node_id]:
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping[node_id]:
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get_variable_value(
variable_selector=value_selector
)
if value is None:
break
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
yield NodeRunStreamChunkEvent(
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
route_position = self.route_position[answer_node_id]
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
@classmethod
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file_var = cls._get_file_var_from_value(item)
if file_var:
files.append(file_var)
elif isinstance(value, dict):
file_var = cls._get_file_var_from_value(value)
if file_var:
files.append(file_var)
return files
@classmethod
def _get_file_var_from_value(self, value: dict | list) -> Optional[dict]:
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
return None

View File

@ -42,11 +42,21 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
text: str = Field(..., description="text")
class AnswerNodeDoubleLink(BaseModel):
node_id: str = Field(..., description="node id")
source_node_ids: list[str] = Field(..., description="source node ids")
target_node_ids: list[str] = Field(..., description="target node ids")
class AnswerStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
AnswerStreamGenerateRoute entity
"""
answer_node_id: str = Field(..., description="answer node ID")
generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
current_route_position: int = 0
"""current generate route position"""
answer_dependencies: dict[str, list[str]] = Field(
...,
description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
...,
description="answer generate route (answer node id -> generate route chunks)"
)

View File

@ -1,29 +1,55 @@
import time
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_answer():
node = AnswerNode(
graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'answer',
'data': {
'title': '123',
'type': 'answer',
'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
}
}
call_depth=0
)
# construct variable pool
@ -34,11 +60,28 @@ def test_execute_answer():
pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
node = AnswerNode(
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
start_at=time.perf_counter()
),
config={
'id': 'answer',
'data': {
'title': '123',
'type': 'answer',
'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
}
}
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

View File

@ -0,0 +1,125 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
def test_init():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
{
"data": {
"type": "answer",
"title": "answer",
"answer": "1{{#llm2.text#}}2"
},
"id": "answer",
},
{
"data": {
"type": "answer",
"title": "answer2",
"answer": "1{{#llm3.text#}}2"
},
"id": "answer2",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=graph.node_id_config_mapping,
reverse_edge_mapping=graph.reverse_edge_mapping
)
assert answer_stream_generate_route.answer_dependencies['answer'] == ['answer2']
assert answer_stream_generate_route.answer_dependencies['answer2'] == []

View File

@ -0,0 +1,206 @@
from collections.abc import Generator
from datetime import datetime, timezone
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
if next_node_id == 'start':
yield from _publish_events(graph, next_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _publish_events(graph, edge.target_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _recursive_process(graph, edge.target_node_id)
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
route_node_state = RouteNodeState(
node_id=next_node_id,
start_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
)
if 'llm' in next_node_id:
length = int(next_node_id[-1])
for i in range(0, length):
yield NodeRunStreamChunkEvent(
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
parallel_id=graph.node_parallel_mapping.get(next_node_id),
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
)
def test_process():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
{
"data": {
"type": "answer",
"title": "answer",
"answer": "a{{#llm2.text#}}b"
},
"id": "answer",
},
{
"data": {
"type": "answer",
"title": "answer2",
"answer": "c{{#llm3.text#}}d"
},
"id": "answer2",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
answer_stream_processor = AnswerStreamProcessor(
graph=graph,
variable_pool=variable_pool
)
def graph_generator() -> Generator[GraphEngineEvent, None, None]:
# print("")
for event in _recursive_process(graph, "start"):
# print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id,
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunSucceededEvent):
if 'llm' in event.route_node_state.node_id:
variable_pool.append_variable(
event.route_node_state.node_id,
["text"],
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
)
yield event
result_generator = answer_stream_processor.process(graph_generator())
stream_contents = ""
for event in result_generator:
# print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id,
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunStreamChunkEvent):
stream_contents += event.chunk_content
pass
assert stream_contents == "c012da01b"

View File

@ -1,21 +1,75 @@
import time
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_if_else_result_true():
node = IfElseNode(
graph_config = {
"edges": [],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
}
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
node = IfElseNode(
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
start_at=time.perf_counter()
),
config={
'id': 'if-else',
'data': {
@ -116,34 +170,11 @@ def test_execute_if_else_result_true():
}
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is True