mirror of https://github.com/langgenius/dify.git
finished answer stream output
This commit is contained in:
parent 7ad77e9e77
commit dad1a967ee
@@ -2,7 +2,7 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk

@@ -33,7 +33,6 @@ from core.app.entities.task_entities import (
    AdvancedChatTaskState,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
    ChatflowStreamGenerateRoute,
    ErrorStreamResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,

@@ -43,20 +42,16 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
    Workflow,
    WorkflowNodeExecution,
    WorkflowRunStatus,
)

@@ -430,102 +425,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
            **extras
        )

    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
        """
        Get stream generate routes.
        :return:
        """
        # find all answer nodes
        graph = self._workflow.graph_dict
        answer_node_configs = [
            node for node in graph['nodes']
            if node.get('data', {}).get('type') == NodeType.ANSWER.value
        ]

        # parse stream output node value selectors of answer nodes
        stream_generate_routes = {}
        for node_config in answer_node_configs:
            # get generate route for stream output
            answer_node_id = node_config['id']
            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
            start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
            if not start_node_ids:
                continue

            for start_node_id in start_node_ids:
                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
                    answer_node_id=answer_node_id,
                    generate_route=generate_route
                )

        return stream_generate_routes

    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
            -> list[str]:
        """
        Get answer start at node id.
        :param graph: graph
        :param target_node_id: target node ID
        :return:
        """
        nodes = graph.get('nodes')
        edges = graph.get('edges')

        # fetch all ingoing edges from source node
        ingoing_edges = []
        for edge in edges:
            if edge.get('target') == target_node_id:
                ingoing_edges.append(edge)

        if not ingoing_edges:
            # check if it's the first node in the iteration
            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
            if not target_node:
                return []

            node_iteration_id = target_node.get('data', {}).get('iteration_id')
            # get iteration start node id
            for node in nodes:
                if node.get('id') == node_iteration_id:
                    if node.get('data', {}).get('start_node_id') == target_node_id:
                        return [target_node_id]

            return []

        start_node_ids = []
        for ingoing_edge in ingoing_edges:
            source_node_id = ingoing_edge.get('source')
            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
            if not source_node:
                continue

            node_type = source_node.get('data', {}).get('type')
            node_iteration_id = source_node.get('data', {}).get('iteration_id')
            iteration_start_node_id = None
            if node_iteration_id:
                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')

            if node_type in [
                NodeType.ANSWER.value,
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value,
                NodeType.ITERATION.value,
                NodeType.LOOP.value
            ]:
                start_node_id = target_node_id
                start_node_ids.append(start_node_id)
            elif node_type == NodeType.START.value or \
                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                start_node_id = source_node_id
                start_node_ids.append(start_node_id)
            else:
                sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
                if sub_start_node_ids:
                    start_node_ids.extend(sub_start_node_ids)

        return start_node_ids

    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
        """
        Get iteration nested relations.

@@ -546,205 +445,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
            ] for iteration_id in iteration_ids
        }

    def _generate_stream_outputs_when_node_started(self) -> Generator:
        """
        Generate stream outputs.
        :return:
        """
        if self._task_state.current_stream_generate_state:
            route_chunks = self._task_state.current_stream_generate_state.generate_route[
                self._task_state.current_stream_generate_state.current_route_position:
            ]

            for route_chunk in route_chunks:
                if route_chunk.type == 'text':
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)

                    # handle output moderation chunk
                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
                    if should_direct_answer:
                        continue

                    self._task_state.answer += route_chunk.text
                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
                else:
                    break

                self._task_state.current_stream_generate_state.current_route_position += 1

            # all route chunks are generated
            if self._task_state.current_stream_generate_state.current_route_position == len(
                    self._task_state.current_stream_generate_state.generate_route
            ):
                self._task_state.current_stream_generate_state = None

    def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
        """
        Generate stream outputs.
        :return:
        """
        if not self._task_state.current_stream_generate_state:
            return

        route_chunks = self._task_state.current_stream_generate_state.generate_route[
            self._task_state.current_stream_generate_state.current_route_position:]

        for route_chunk in route_chunks:
            if route_chunk.type == 'text':
                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                self._task_state.answer += route_chunk.text
                yield self._message_to_stream_response(route_chunk.text, self._message.id)
            else:
                value = None
                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector
                if not value_selector:
                    self._task_state.current_stream_generate_state.current_route_position += 1
                    continue

                route_chunk_node_id = value_selector[0]

                if route_chunk_node_id == 'sys':
                    # system variable
                    value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
                elif route_chunk_node_id in self._iteration_nested_relations:
                    # it's a iteration variable
                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
                        continue
                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
                    iterator = iteration_state.inputs
                    if not iterator:
                        continue
                    iterator_selector = iterator.get('iterator_selector', [])
                    if value_selector[1] == 'index':
                        value = iteration_state.current_index
                    elif value_selector[1] == 'item':
                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
                            iterator_selector
                        ) else None
                else:
                    # check chunk node id is before current node id or equal to current node id
                    if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
                        break

                    latest_node_execution_info = self._task_state.latest_node_execution_info

                    # get route chunk node execution info
                    route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
                    if (route_chunk_node_execution_info.node_type == NodeType.LLM
                            and latest_node_execution_info.node_type == NodeType.LLM):
                        # only LLM support chunk stream output
                        self._task_state.current_stream_generate_state.current_route_position += 1
                        continue

                    # get route chunk node execution
                    route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
                    ).first()

                    outputs = route_chunk_node_execution.outputs_dict

                    # get value from outputs
                    value = None
                    for key in value_selector[1:]:
                        if not value:
                            value = outputs.get(key) if outputs else None
                        else:
                            value = value.get(key)

                if value is not None:
                    text = ''
                    if isinstance(value, str | int | float):
                        text = str(value)
                    elif isinstance(value, FileVar):
                        # convert file to markdown
                        text = value.to_markdown()
                    elif isinstance(value, dict):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        if file_vars:
                            file_var = file_vars[0]
                            try:
                                file_var_obj = FileVar(**file_var)

                                # convert file to markdown
                                text = file_var_obj.to_markdown()
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')

                        if not text:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)
                    elif isinstance(value, list):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        for file_var in file_vars:
                            try:
                                file_var_obj = FileVar(**file_var)
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')
                                continue

                            # convert file to markdown
                            text = file_var_obj.to_markdown() + ' '

                        text = text.strip()

                        if not text and value:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)

                    if text:
                        self._task_state.answer += text
                        yield self._message_to_stream_response(text, self._message.id)

            self._task_state.current_stream_generate_state.current_route_position += 1

        # all route chunks are generated
        if self._task_state.current_stream_generate_state.current_route_position == len(
                self._task_state.current_stream_generate_state.generate_route
        ):
            self._task_state.current_stream_generate_state = None

    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
        """
        Is stream out support
        :param event: queue text chunk event
        :return:
        """
        if not event.metadata:
            return True

        if 'node_id' not in event.metadata:
            return True

        node_type = event.metadata.get('node_type')
        stream_output_value_selector = event.metadata.get('value_selector')
        if not stream_output_value_selector:
            return False

        if not self._task_state.current_stream_generate_state:
            return False

        route_chunk = self._task_state.current_stream_generate_state.generate_route[
            self._task_state.current_stream_generate_state.current_route_position]

        if route_chunk.type != 'var':
            return False

        if node_type != NodeType.LLM:
            # only LLM support chunk stream output
            return False

        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
        value_selector = route_chunk.value_selector

        # check chunk node id is before current node id or equal to current node id
        if value_selector != stream_output_value_selector:
            return False

        return True

    def _handle_output_moderation_chunk(self, text: str) -> bool:
        """
        Handle output moderation chunk.

@@ -50,7 +50,8 @@ class NodeRunStartedEvent(BaseNodeEvent):

class NodeRunStreamChunkEvent(BaseNodeEvent):
    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: list[str] = Field(..., description="from variable selector")
    from_variable_selector: Optional[list[str]] = None
    """from variable selector"""


class NodeRunRetrieverResourceEvent(BaseNodeEvent):

@@ -5,21 +5,24 @@ from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute


class GraphEdge(BaseModel):
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    run_condition: Optional[RunCondition] = Field(None, description="run condition")
    run_condition: Optional[RunCondition] = None
    """run condition"""


class GraphParallel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    parent_parallel_id: Optional[str] = Field(None, description="parent parallel id")
    end_to_node_id: Optional[str] = Field(None, description="end to node id")
    parent_parallel_id: Optional[str] = None
    """parent parallel id"""
    end_to_node_id: Optional[str] = None
    """end to node id"""


class Graph(BaseModel):

@@ -33,6 +36,10 @@ class Graph(BaseModel):
        default_factory=dict,
        description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict,
        description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict,
        description="graph parallel mapping (parallel id: parallel)"

@@ -41,8 +48,8 @@ class Graph(BaseModel):
        default_factory=dict,
        description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
        default_factory=dict,
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
        ...,
        description="answer stream generate routes"
    )

@@ -66,6 +73,7 @@ class Graph(BaseModel):

        # reorganize edges mapping
        edge_mapping: dict[str, list[GraphEdge]] = {}
        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
        target_edge_ids = set()
        for edge_config in edge_configs:
            source_node_id = edge_config.get('source')

@@ -79,6 +87,9 @@ class Graph(BaseModel):
            if not target_node_id:
                continue

            if target_node_id not in reverse_edge_mapping:
                reverse_edge_mapping[target_node_id] = []

            target_edge_ids.add(target_node_id)

            # parse run condition

@@ -91,11 +102,12 @@ class Graph(BaseModel):

            graph_edge = GraphEdge(
                source_node_id=source_node_id,
                target_node_id=edge_config.get('target'),
                target_node_id=target_node_id,
                run_condition=run_condition
            )

            edge_mapping[source_node_id].append(graph_edge)
            reverse_edge_mapping[target_node_id].append(graph_edge)

        # node configs
        node_configs = graph_config.get('nodes')

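Aside: the hunks above maintain a reverse_edge_mapping (target node id -> ingoing edges) alongside the existing forward edge_mapping, so later passes can walk the graph backwards without scanning every edge. A minimal standalone sketch of the same bookkeeping, with simplified edge dicts in place of the project's GraphEdge model (illustrative names, not the project's API):

from collections import defaultdict

# Simplified edge records; the real configs also carry ids and run conditions.
edge_configs = [
    {"source": "start", "target": "llm"},
    {"source": "llm", "target": "answer"},
]

edge_mapping: dict[str, list[dict]] = defaultdict(list)          # source -> outgoing edges
reverse_edge_mapping: dict[str, list[dict]] = defaultdict(list)  # target -> ingoing edges

for edge in edge_configs:
    edge_mapping[edge["source"]].append(edge)
    reverse_edge_mapping[edge["target"]].append(edge)

# Walking backwards from a node is now a dict lookup instead of a full edge scan:
assert [e["source"] for e in reverse_edge_mapping["answer"]] == ["llm"]
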
@@ -149,9 +161,9 @@ class Graph(BaseModel):
        )

        # init answer stream generate routes
        answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping
            reverse_edge_mapping=reverse_edge_mapping
        )

        # init graph

@@ -160,6 +172,7 @@ class Graph(BaseModel):
            node_ids=node_ids,
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            answer_stream_generate_routes=answer_stream_generate_routes

@@ -8,7 +8,11 @@ class GraphRuntimeState(BaseModel):
    variable_pool: VariablePool = Field(..., description="variable pool")

    start_at: float = Field(..., description="start time")
    total_tokens: int = Field(0, description="total tokens")
    node_run_steps: int = Field(0, description="node run steps")
    total_tokens: int = 0
    """total tokens"""

    node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState, description="node run state")
    node_run_steps: int = 0
    """node run steps"""

    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""

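Aside: the GraphRuntimeState hunk swaps pydantic Field(default, description=...) declarations for plain class-level defaults documented by attribute docstrings. Both spellings produce the same default values; a minimal comparison, assuming ordinary pydantic BaseModel semantics as used elsewhere in this commit:

from pydantic import BaseModel, Field


class WithField(BaseModel):
    total_tokens: int = Field(0, description="total tokens")


class WithPlainDefault(BaseModel):
    total_tokens: int = 0
    """total tokens"""


# Defaults behave identically; only the attached schema metadata differs.
assert WithField().total_tokens == WithPlainDefault().total_tokens == 0
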
@@ -28,6 +28,9 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor

# from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db

@@ -81,6 +84,10 @@ class GraphEngine:
        try:
            # run graph
            generator = self._run(start_node_id=self.graph.root_node_id)
            if self.init_params.workflow_type == WorkflowType.CHAT:
                answer_stream_processor = AnswerStreamProcessor(self.graph)
                generator = answer_stream_processor.process(generator)

            for item in generator:
                yield item
                if isinstance(item, NodeRunFailedEvent):

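Aside: the engine change wraps the raw event generator in an AnswerStreamProcessor only for CHAT workflows, which is plain generator composition. A self-contained toy of that wrapping pattern (names here are illustrative, not the project's API):

from collections.abc import Generator


def engine_events() -> Generator[str, None, None]:
    yield from ["node_started", "stream_chunk", "node_succeeded"]


def stream_processor(events: Generator[str, None, None]) -> Generator[str, None, None]:
    # Pass events through, transforming or withholding some of them on the way.
    for event in events:
        yield event.upper() if event == "stream_chunk" else event


is_chat = True
generator = engine_events()
if is_chat:
    generator = stream_processor(generator)  # wrap only when the mode requires it

assert list(generator) == ["node_started", "STREAM_CHUNK", "node_succeeded"]
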
@@ -314,8 +321,6 @@ class GraphEngine:

        db.session.close()

        # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node

        self.graph_runtime_state.node_run_steps += 1

        try:

@@ -335,7 +340,7 @@ class GraphEngine:
            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                # plus state total_tokens
                self.graph_runtime_state.total_tokens += int(
                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                )

            # append node output variables to variable pool

@@ -397,7 +402,7 @@ class GraphEngine:
            self.graph_runtime_state.variable_pool.append_variable(
                node_id=node_id,
                variable_key_list=variable_key_list,
                value=variable_value
                value=variable_value  # type: ignore[arg-type]
            )

            # if variable_value is a dict, then recursively append variables

@@ -4,7 +4,7 @@ from typing import cast
from core.file.file_obj import FileVar
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    GenerateRouteChunk,

@@ -29,7 +29,7 @@ class AnswerNode(BaseNode):
        node_data = cast(AnswerNodeData, node_data)

        # generate routes
        generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)

        answer = ''
        for part in generate_routes:

@@ -0,0 +1,203 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    AnswerStreamGenerateRoute,
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser


class AnswerStreamGeneratorRouter:

    @classmethod
    def init(cls,
             node_id_config_mapping: dict[str, dict],
             reverse_edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
             ) -> AnswerStreamGenerateRoute:
        """
        Get stream generate routes.
        :return:
        """
        # parse stream output node value selectors of answer nodes
        answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
        for answer_node_id, node_config in node_id_config_mapping.items():
            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
                continue

            # get generate route for stream output
            generate_route = cls._extract_generate_route_selectors(node_config)
            answer_generate_route[answer_node_id] = generate_route

        # fetch answer dependencies
        answer_node_ids = list(answer_generate_route.keys())
        answer_dependencies = cls._fetch_answers_dependencies(
            answer_node_ids=answer_node_ids,
            reverse_edge_mapping=reverse_edge_mapping,
            node_id_config_mapping=node_id_config_mapping
        )

        return AnswerStreamGenerateRoute(
            answer_generate_route=answer_generate_route,
            answer_dependencies=answer_dependencies
        )

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data
        :param node_data: node data object
        :return:
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        variable_keys = list(value_selector_mapping.keys())

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))

        template = node_data.answer
        for var in variable_keys:
            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

        generate_routes: list[GenerateRouteChunk] = []
        for part in template.split('Ω'):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(
                        value_selector=value_selector
                    ))
                else:
                    generate_routes.append(TextGenerateRouteChunk(
                        text=part
                    ))

        return generate_routes

    @classmethod
    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors
        :param config: node config
        :return:
        """
        node_data = AnswerNodeData(**config.get("data", {}))
        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def _is_variable(cls, part, variable_keys):
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

    @classmethod
    def _fetch_answers_dependencies(cls,
                                    answer_node_ids: list[str],
                                    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                    node_id_config_mapping: dict[str, dict]
                                    ) -> dict[str, list[str]]:
        """
        Fetch answer dependencies
        :param answer_node_ids: answer node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param node_id_config_mapping: node id config mapping
        :return:
        """
        answer_dependencies: dict[str, list[str]] = {}
        for answer_node_id in answer_node_ids:
            if answer_dependencies.get(answer_node_id) is None:
                answer_dependencies[answer_node_id] = []

            cls._recursive_fetch_answer_dependencies(
                current_node_id=answer_node_id,
                answer_node_id=answer_node_id,
                node_id_config_mapping=node_id_config_mapping,
                reverse_edge_mapping=reverse_edge_mapping,
                answer_dependencies=answer_dependencies
            )

        return answer_dependencies

    @classmethod
    def _recursive_fetch_answer_dependencies(cls,
                                             current_node_id: str,
                                             answer_node_id: str,
                                             node_id_config_mapping: dict[str, dict],
                                             reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                             answer_dependencies: dict[str, list[str]]
                                             ) -> None:
        """
        Recursive fetch answer dependencies
        :param current_node_id: current node id
        :param answer_node_id: answer node id
        :param node_id_config_mapping: node id config mapping
        :param reverse_edge_mapping: reverse edge mapping
        :param answer_dependencies: answer dependencies
        :return:
        """
        reverse_edges = reverse_edge_mapping.get(current_node_id, [])
        for edge in reverse_edges:
            source_node_id = edge.source_node_id
            source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
            if source_node_type in (
                    NodeType.ANSWER.value,
                    NodeType.IF_ELSE.value,
                    NodeType.QUESTION_CLASSIFIER,
            ):
                answer_dependencies[answer_node_id].append(source_node_id)
            else:
                cls._recursive_fetch_answer_dependencies(
                    current_node_id=source_node_id,
                    answer_node_id=answer_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    answer_dependencies=answer_dependencies
                )

    @classmethod
    def _fetch_answer_dependencies(cls,
                                   current_node_id: str,
                                   answer_node_id: str,
                                   answer_node_ids: list[str],
                                   reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
                                   answer_dependencies: dict[str, list[str]]
                                   ) -> None:
        """
        Fetch answer dependencies
        :param current_node_id: current node id
        :param answer_node_id: answer node id
        :param answer_node_ids: answer node ids
        :param reverse_edge_mapping: reverse edge mapping
        :param answer_dependencies: answer dependencies
        :return:
        """
        for edge in reverse_edge_mapping.get(current_node_id, []):
            source_node_id = edge.source_node_id
            if source_node_id == answer_node_id:
                continue

            if source_node_id in answer_node_ids:
                # is answer node
                answer_dependencies[answer_node_id].append(source_node_id)
            else:
                cls._fetch_answer_dependencies(
                    current_node_id=source_node_id,
                    answer_node_id=answer_node_id,
                    answer_node_ids=answer_node_ids,
                    reverse_edge_mapping=reverse_edge_mapping,
                    answer_dependencies=answer_dependencies
                )

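Aside: extract_generate_route_from_node_data above splits an answer template into alternating text and variable chunks by fencing each resolvable variable with a sentinel character (Ω) and splitting on it. A standalone sketch of just that splitting step, with tuples standing in for the Text/VarGenerateRouteChunk classes:

template = "Today is {{weather}} and {{mood}}."
variable_keys = ["weather"]  # only variables that resolved to a value selector

# Fence each known variable so a single split yields text/var alternation.
for var in variable_keys:
    template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω")

routes = []
for part in template.split("Ω"):
    if not part:
        continue
    key = part.replace("{{", "").replace("}}", "")
    if part.startswith("{{") and key in variable_keys:
        routes.append(("var", key))    # VarGenerateRouteChunk in the real code
    else:
        routes.append(("text", part))  # TextGenerateRouteChunk in the real code

assert routes == [("text", "Today is "), ("var", "weather"), ("text", " and {{mood}}.")]
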
@@ -1,160 +0,0 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
    AnswerNodeData,
    AnswerStreamGenerateRoute,
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser


class AnswerStreamOutputManager:

    @classmethod
    def init_stream_generate_routes(cls,
                                    node_id_config_mapping: dict[str, dict],
                                    edge_mapping: dict[str, list["GraphEdge"]]  # type: ignore[name-defined]
                                    ) -> dict[str, AnswerStreamGenerateRoute]:
        """
        Get stream generate routes.
        :return:
        """
        # parse stream output node value selectors of answer nodes
        stream_generate_routes = {}
        for node_id, node_config in node_id_config_mapping.items():
            if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
                continue

            # get generate route for stream output
            generate_route = cls._extract_generate_route_selectors(node_config)
            streaming_node_ids = cls._get_streaming_node_ids(
                target_node_id=node_id,
                node_id_config_mapping=node_id_config_mapping,
                edge_mapping=edge_mapping
            )

            if not streaming_node_ids:
                continue

            for streaming_node_id in streaming_node_ids:
                stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
                    answer_node_id=node_id,
                    generate_route=generate_route
                )

        return stream_generate_routes

    @classmethod
    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
        """
        Extract generate route from node data
        :param node_data: node data object
        :return:
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        value_selector_mapping = {
            variable_selector.variable: variable_selector.value_selector
            for variable_selector in variable_selectors
        }

        variable_keys = list(value_selector_mapping.keys())

        # format answer template
        template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
        template_variable_keys = template_parser.variable_keys

        # Take the intersection of variable_keys and template_variable_keys
        variable_keys = list(set(variable_keys) & set(template_variable_keys))

        template = node_data.answer
        for var in variable_keys:
            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

        generate_routes: list[GenerateRouteChunk] = []
        for part in template.split('Ω'):
            if part:
                if cls._is_variable(part, variable_keys):
                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
                    value_selector = value_selector_mapping[var_key]
                    generate_routes.append(VarGenerateRouteChunk(
                        value_selector=value_selector
                    ))
                else:
                    generate_routes.append(TextGenerateRouteChunk(
                        text=part
                    ))

        return generate_routes

    @classmethod
    def _get_streaming_node_ids(cls,
                                target_node_id: str,
                                node_id_config_mapping: dict[str, dict],
                                edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]:  # type: ignore[name-defined]
        """
        Get answer stream node IDs.
        :param target_node_id: target node ID
        :return:
        """
        # fetch all ingoing edges from source node
        ingoing_graph_edges = []
        for graph_edges in edge_mapping.values():
            for graph_edge in graph_edges:
                if graph_edge.target_node_id == target_node_id:
                    ingoing_graph_edges.append(graph_edge)

        if not ingoing_graph_edges:
            return []

        streaming_node_ids = []
        for ingoing_graph_edge in ingoing_graph_edges:
            source_node_id = ingoing_graph_edge.source_node_id
            source_node = node_id_config_mapping.get(source_node_id)
            if not source_node:
                continue

            node_type = source_node.get('data', {}).get('type')

            if node_type in [
                NodeType.ANSWER.value,
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value,
                NodeType.ITERATION.value,
                NodeType.LOOP.value
            ]:
                # Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
                streaming_node_ids.append(target_node_id)
            elif node_type == NodeType.START.value:
                # Current node is START node, can be stream node.
                streaming_node_ids.append(source_node_id)
            else:
                # Find the stream node forward.
                sub_streaming_node_ids = cls._get_streaming_node_ids(
                    target_node_id=source_node_id,
                    node_id_config_mapping=node_id_config_mapping,
                    edge_mapping=edge_mapping
                )

                if sub_streaming_node_ids:
                    streaming_node_ids.extend(sub_streaming_node_ids)

        return streaming_node_ids

    @classmethod
    def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
        """
        Extract generate route selectors
        :param config: node config
        :return:
        """
        node_data = AnswerNodeData(**config.get("data", {}))
        return cls.extract_generate_route_from_node_data(node_data)

    @classmethod
    def _is_variable(cls, part, variable_keys):
        cleaned_part = part.replace('{{', '').replace('}}', '')
        return part.startswith('{{') and cleaned_part in variable_keys

@@ -0,0 +1,286 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, cast

from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk

logger = logging.getLogger(__name__)


class AnswerStreamProcessor:

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        self.generate_routes = graph.answer_stream_generate_routes
        self.route_position = {}
        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        for event in generator:
            if isinstance(event, NodeRunStreamChunkEvent):
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ]
                else:
                    stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
                    self.current_stream_chunk_generating_node_ids[
                        event.route_node_state.node_id
                    ] = stream_out_answer_node_ids

                for _ in stream_out_answer_node_ids:
                    yield event
            elif isinstance(event, NodeRunSucceededEvent):
                yield event
                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                    # update self.route_position after all stream event finished
                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
                        self.route_position[answer_node_id] += 1

                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]

                # remove unreachable nodes
                self._remove_unreachable_nodes(event)

                # generate stream outputs
                yield from self._generate_stream_outputs_when_node_finished(event)
            else:
                yield event

    def reset(self) -> None:
        self.route_position = {}
        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = self.graph.node_ids.copy()

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
        finished_node_id = event.route_node_state.node_id

        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (edge.run_condition
                        and edge.run_condition.branch_identify
                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    continue
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping[node_id]:
            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge
        """
        self.rest_node_ids.remove(node_id)
        for edge in self.graph.edge_mapping[node_id]:
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

    def _generate_stream_outputs_when_node_finished(self,
                                                    event: NodeRunSucceededEvent
                                                    ) -> Generator[GraphEngineEvent, None, None]:
        """
        Generate stream outputs.
        :param event: node run succeeded event
        :return:
        """
        for answer_node_id, position in self.route_position.items():
            # all depends on answer node id not in rest node ids
            if not all(dep_id not in self.rest_node_ids
                       for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
                continue

            route_position = self.route_position[answer_node_id]
            route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]

            for route_chunk in route_chunks:
                if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                    yield NodeRunStreamChunkEvent(
                        chunk_content=route_chunk.text,
                        route_node_state=event.route_node_state,
                        parallel_id=event.parallel_id,
                    )
                else:
                    route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                    value_selector = route_chunk.value_selector
                    if not value_selector:
                        break

                    value = self.variable_pool.get_variable_value(
                        variable_selector=value_selector
                    )

                    if value is None:
                        break

                    text = ''
                    if isinstance(value, str | int | float):
                        text = str(value)
                    elif isinstance(value, FileVar):
                        # convert file to markdown
                        text = value.to_markdown()
                    elif isinstance(value, dict):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        if file_vars:
                            file_var = file_vars[0]
                            try:
                                file_var_obj = FileVar(**file_var)

                                # convert file to markdown
                                text = file_var_obj.to_markdown()
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')

                        if not text:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)
                    elif isinstance(value, list):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        for file_var in file_vars:
                            try:
                                file_var_obj = FileVar(**file_var)
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')
                                continue

                            # convert file to markdown
                            text = file_var_obj.to_markdown() + ' '

                        text = text.strip()

                        if not text and value:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)

                    if text:
                        yield NodeRunStreamChunkEvent(
                            chunk_content=text,
                            from_variable_selector=value_selector,
                            route_node_state=event.route_node_state,
                            parallel_id=event.parallel_id,
                        )

                self.route_position[answer_node_id] += 1

    def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Is stream out support
        :param event: queue text chunk event
        :return:
        """
        if not event.from_variable_selector:
            return []

        stream_output_value_selector = event.from_variable_selector
        if not stream_output_value_selector:
            return []

        stream_out_answer_node_ids = []
        for answer_node_id, position in self.route_position.items():
            if answer_node_id not in self.rest_node_ids:
                continue

            # all depends on answer node id not in rest node ids
            if all(dep_id not in self.rest_node_ids
                   for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
                route_position = self.route_position[answer_node_id]
                if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
                    continue

                route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]

                if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
                    continue

                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector

                # check chunk node id is before current node id or equal to current node id
                if value_selector != stream_output_value_selector:
                    continue

                stream_out_answer_node_ids.append(answer_node_id)

        return stream_out_answer_node_ids

    @classmethod
    def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
        """
        Fetch files from variable value
        :param value: variable value
        :return:
        """
        if not value:
            return []

        files = []
        if isinstance(value, list):
            for item in value:
                file_var = cls._get_file_var_from_value(item)
                if file_var:
                    files.append(file_var)
        elif isinstance(value, dict):
            file_var = cls._get_file_var_from_value(value)
            if file_var:
                files.append(file_var)

        return files

    @classmethod
    def _get_file_var_from_value(self, value: dict | list) -> Optional[dict]:
        """
        Get file var from value
        :param value: variable value
        :return:
        """
        if not value:
            return None

        if isinstance(value, dict):
            if '__variant' in value and value['__variant'] == FileVar.__name__:
                return value
        elif isinstance(value, FileVar):
            return value.to_dict()

        return None

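Aside: the processor's gate is that an answer node may emit (or continue) its route only once every node in its answer_dependencies list has finished, i.e. has left rest_node_ids. A reduced model of that check, with plain dicts standing in for the pydantic entities:

answer_dependencies = {"answer": ["answer2"], "answer2": []}
rest_node_ids = {"answer", "answer2", "llm3"}  # nodes that have not finished yet


def ready_to_stream(answer_node_id: str) -> bool:
    # An answer may stream once none of its dependencies is still unfinished.
    return all(dep not in rest_node_ids for dep in answer_dependencies[answer_node_id])


assert ready_to_stream("answer2")      # no dependencies: streams immediately
assert not ready_to_stream("answer")   # blocked until answer2 finishes

rest_node_ids.discard("answer2")       # answer2's NodeRunSucceededEvent arrives
assert ready_to_stream("answer")
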
@@ -42,11 +42,21 @@ class TextGenerateRouteChunk(GenerateRouteChunk):
    text: str = Field(..., description="text")


class AnswerNodeDoubleLink(BaseModel):
    node_id: str = Field(..., description="node id")
    source_node_ids: list[str] = Field(..., description="source node ids")
    target_node_ids: list[str] = Field(..., description="target node ids")


class AnswerStreamGenerateRoute(BaseModel):
    """
    ChatflowStreamGenerateRoute entity
    AnswerStreamGenerateRoute entity
    """
    answer_node_id: str = Field(..., description="answer node ID")
    generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
    current_route_position: int = 0
    """current generate route position"""
    answer_dependencies: dict[str, list[str]] = Field(
        ...,
        description="answer dependencies (answer node id -> dependent answer node ids)"
    )
    answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
        ...,
        description="answer generate route (answer node id -> generate route chunks)"
    )

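Aside: after this change one AnswerStreamGenerateRoute aggregates every answer node's route and its dependencies, instead of one entity per stream start node. A minimal construction example against stand-in pydantic fields (the real answer_generate_route values are GenerateRouteChunk objects, simplified to strings here):

from pydantic import BaseModel, Field


class AnswerStreamGenerateRouteSketch(BaseModel):
    answer_dependencies: dict[str, list[str]] = Field(..., description="answer node id -> dependent answer node ids")
    answer_generate_route: dict[str, list[str]] = Field(..., description="answer node id -> route chunks (simplified)")


route = AnswerStreamGenerateRouteSketch(
    answer_dependencies={"answer": ["answer2"], "answer2": []},
    answer_generate_route={"answer": ["a", "{{#llm2.text#}}", "b"], "answer2": ["c", "{{#llm3.text#}}", "d"]},
)
assert route.answer_dependencies["answer"] == ["answer2"]
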
@@ -1,29 +1,55 @@
import time
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


def test_execute_answer():
    node = AnswerNode(
    graph_config = {
        "edges": [
            {
                "id": "start-source-llm-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [
            {
                "data": {
                    "type": "start"
                },
                "id": "start"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm"
            },
        ]
    }

    graph = Graph.init(
        graph_config=graph_config
    )

    init_params = GraphInitParams(
        tenant_id='1',
        app_id='1',
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id='1',
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        config={
            'id': 'answer',
            'data': {
                'title': '123',
                'type': 'answer',
                'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
            }
        }
        call_depth=0
    )

    # construct variable pool

@@ -34,11 +60,28 @@ def test_execute_answer():
    pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
    pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')

    node = AnswerNode(
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(
            variable_pool=pool,
            start_at=time.perf_counter()
        ),
        config={
            'id': 'answer',
            'data': {
                'title': '123',
                'type': 'answer',
                'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
            }
        }
    )

    # Mock db.session.close()
    db.session.close = MagicMock()

    # execute node
    result = node._run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

@@ -0,0 +1,125 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter


def test_init():
    graph_config = {
        "edges": [
            {
                "id": "start-source-llm1-target",
                "source": "start",
                "target": "llm1",
            },
            {
                "id": "start-source-llm2-target",
                "source": "start",
                "target": "llm2",
            },
            {
                "id": "start-source-llm3-target",
                "source": "start",
                "target": "llm3",
            },
            {
                "id": "llm3-source-llm4-target",
                "source": "llm3",
                "target": "llm4",
            },
            {
                "id": "llm3-source-llm5-target",
                "source": "llm3",
                "target": "llm5",
            },
            {
                "id": "llm4-source-answer2-target",
                "source": "llm4",
                "target": "answer2",
            },
            {
                "id": "llm5-source-answer-target",
                "source": "llm5",
                "target": "answer",
            },
            {
                "id": "answer2-source-answer-target",
                "source": "answer2",
                "target": "answer",
            },
            {
                "id": "llm2-source-answer-target",
                "source": "llm2",
                "target": "answer",
            },
            {
                "id": "llm1-source-answer-target",
                "source": "llm1",
                "target": "answer",
            }
        ],
        "nodes": [
            {
                "data": {
                    "type": "start"
                },
                "id": "start"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm1"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm2"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm3"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm4"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm5"
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer",
                    "answer": "1{{#llm2.text#}}2"
                },
                "id": "answer",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer2",
                    "answer": "1{{#llm3.text#}}2"
                },
                "id": "answer2",
            },
        ],
    }

    graph = Graph.init(
        graph_config=graph_config
    )

    answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
        node_id_config_mapping=graph.node_id_config_mapping,
        reverse_edge_mapping=graph.reverse_edge_mapping
    )

    assert answer_stream_generate_route.answer_dependencies['answer'] == ['answer2']
    assert answer_stream_generate_route.answer_dependencies['answer2'] == []

@ -0,0 +1,206 @@
|
|||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
|
||||
|
||||
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
if next_node_id == 'start':
|
||||
yield from _publish_events(graph, next_node_id)
|
||||
|
||||
for edge in graph.edge_mapping.get(next_node_id, []):
|
||||
yield from _publish_events(graph, edge.target_node_id)
|
||||
|
||||
for edge in graph.edge_mapping.get(next_node_id, []):
|
||||
yield from _recursive_process(graph, edge.target_node_id)
|
||||
|
||||
|
||||
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||
route_node_state = RouteNodeState(
|
||||
node_id=next_node_id,
|
||||
start_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
|
||||
yield NodeRunStartedEvent(
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
||||
)
|
||||
|
||||
if 'llm' in next_node_id:
|
||||
length = int(next_node_id[-1])
|
||||
for i in range(0, length):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
chunk_content=str(i),
|
||||
route_node_state=route_node_state,
|
||||
from_variable_selector=[next_node_id, "text"],
|
||||
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
||||
)
|
||||
|
||||
route_node_state.status = RouteNodeState.Status.SUCCESS
|
||||
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
yield NodeRunSucceededEvent(
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
||||
)
|
||||
|
||||
|
||||
def test_process():
    graph_config = {
        "edges": [
            {
                "id": "start-source-llm1-target",
                "source": "start",
                "target": "llm1",
            },
            {
                "id": "start-source-llm2-target",
                "source": "start",
                "target": "llm2",
            },
            {
                "id": "start-source-llm3-target",
                "source": "start",
                "target": "llm3",
            },
            {
                "id": "llm3-source-llm4-target",
                "source": "llm3",
                "target": "llm4",
            },
            {
                "id": "llm3-source-llm5-target",
                "source": "llm3",
                "target": "llm5",
            },
            {
                "id": "llm4-source-answer2-target",
                "source": "llm4",
                "target": "answer2",
            },
            {
                "id": "llm5-source-answer-target",
                "source": "llm5",
                "target": "answer",
            },
            {
                "id": "answer2-source-answer-target",
                "source": "answer2",
                "target": "answer",
            },
            {
                "id": "llm2-source-answer-target",
                "source": "llm2",
                "target": "answer",
            },
            {
                "id": "llm1-source-answer-target",
                "source": "llm1",
                "target": "answer",
            }
        ],
        "nodes": [
            {
                "data": {
                    "type": "start"
                },
                "id": "start"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm1"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm2"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm3"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm4"
            },
            {
                "data": {
                    "type": "llm",
                },
                "id": "llm5"
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer",
                    "answer": "a{{#llm2.text#}}b"
                },
                "id": "answer",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer2",
                    "answer": "c{{#llm3.text#}}d"
                },
                "id": "answer2",
            },
        ],
    }
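
    # Topology: "start" fans out to llm1/llm2/llm3; llm3 fans out to llm4/llm5;
    # answer2 streams llm3's text via llm4's branch, and answer streams llm2's
    # text once the llm5, answer2, llm2 and llm1 branches converge on it.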

    graph = Graph.init(
        graph_config=graph_config
    )

    variable_pool = VariablePool(system_variables={
        SystemVariable.QUERY: 'what\'s the weather in SF',
        SystemVariable.FILES: [],
        SystemVariable.CONVERSATION_ID: 'abababa',
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})

    answer_stream_processor = AnswerStreamProcessor(
        graph=graph,
        variable_pool=variable_pool
    )

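    # the unit under test: AnswerStreamProcessor re-routes raw engine chunks so
    # they are emitted in answer-template order rather than arrival order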
    def graph_generator() -> Generator[GraphEngineEvent, None, None]:
        # print("")
        for event in _recursive_process(graph, "start"):
            # print("[ORIGIN]", event.__class__.__name__ + ":", event.route_node_state.node_id,
            #       " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
            if isinstance(event, NodeRunSucceededEvent):
                if 'llm' in event.route_node_state.node_id:
                    # as the real engine would, expose the node's full streamed
                    # text in the variable pool once the node succeeds
                    variable_pool.append_variable(
                        event.route_node_state.node_id,
                        ["text"],
                        "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
                    )
            yield event
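
    # Drain the processed stream and collect chunk text. answer2 renders
    # "c{{#llm3.text#}}d" -> "c" + "012" + "d"; answer renders
    # "a{{#llm2.text#}}b" -> "a" + "01" + "b"; hence "c012da01b".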
    result_generator = answer_stream_processor.process(graph_generator())
    stream_contents = ""
    for event in result_generator:
        # print("[ANSWER]", event.__class__.__name__ + ":", event.route_node_state.node_id,
        #       " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
        if isinstance(event, NodeRunStreamChunkEvent):
            stream_contents += event.chunk_content

    assert stream_contents == "c012da01b"


@@ -1,21 +1,75 @@

import time
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType


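# Covers the "true" branch of IfElseNode across all comparison operators.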
def test_execute_if_else_result_true():
    node = IfElseNode(
    graph_config = {
        "edges": [],
        "nodes": [
            {
                "data": {
                    "type": "start"
                },
                "id": "start"
            }
        ]
    }

    graph = Graph.init(
        graph_config=graph_config
    )

    init_params = GraphInitParams(
        tenant_id='1',
        app_id='1',
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id='1',
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0
    )
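
    # GraphInitParams pins the tenant/app/workflow identity and the invocation
    # context (an account user invoking from the debugger, call depth 0)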

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.FILES: [],
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')

    node = IfElseNode(
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(
            variable_pool=pool,
            start_at=time.perf_counter()
        ),
        config={
            'id': 'if-else',
            'data': {

@@ -116,34 +170,11 @@ def test_execute_if_else_result_true():

        }
    )

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariable.FILES: [],
        SystemVariable.USER_ID: 'aaa'
    }, user_inputs={})
    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')

    # Mock db.session.close()
    db.session.close = MagicMock()

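    # the pool now travels inside GraphRuntimeState, so _run() no longer takes
    # the variable pool argument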
    # execute node
    result = node._run(pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs['result'] is True