mirror of https://github.com/langgenius/dify.git
fix iteration
This commit is contained in:
parent
ae351bd40e
commit
a31feacf28
|
|
@ -104,9 +104,14 @@ class GraphEngine:
|
|||
)
|
||||
|
||||
for item in generator:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
|
||||
try:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Graph run failed: {str(e)}")
|
||||
yield GraphRunFailedEvent(reason=str(e))
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
|
|
@ -115,6 +120,7 @@ class GraphEngine:
|
|||
yield GraphRunFailedEvent(reason=e.error)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when graph running")
|
||||
yield GraphRunFailedEvent(reason=str(e))
|
||||
raise e
|
||||
|
||||
|
|
@ -182,7 +188,22 @@ class GraphEngine:
|
|||
break
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
next_node_id = edge_mappings[0].target_node_id
|
||||
edge = edge_mappings[0]
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
target_node_id=edge.target_node_id,
|
||||
)
|
||||
|
||||
if not result:
|
||||
break
|
||||
|
||||
next_node_id = edge.target_node_id
|
||||
else:
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from core.file.file_obj import FileVar
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
|
@ -31,7 +32,12 @@ class AnswerStreamProcessor:
|
|||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
|
|
@ -99,6 +105,9 @@ class AnswerStreamProcessor:
|
|||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||
node_ids = []
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id == self.graph.root_node_id:
|
||||
continue
|
||||
|
||||
node_ids.append(edge.target_node_id)
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
return node_ids
|
||||
|
|
@ -107,6 +116,9 @@ class AnswerStreamProcessor:
|
|||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from collections.abc import Generator
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
|
@ -26,7 +27,12 @@ class EndStreamProcessor:
|
|||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStreamChunkEvent):
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
|
|
@ -87,6 +93,9 @@ class EndStreamProcessor:
|
|||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||
node_ids = []
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id == self.graph.root_node_id:
|
||||
continue
|
||||
|
||||
node_ids.append(edge.target_node_id)
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
return node_ids
|
||||
|
|
@ -95,6 +104,9 @@ class EndStreamProcessor:
|
|||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
|
|
|
|||
|
|
@ -5,15 +5,13 @@ from configs import dify_config
|
|||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.graph_engine.entities.event import GraphRunFailedEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.event import BaseGraphEvent, GraphRunFailedEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.workflow_entry import WorkflowRunFailedError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -74,7 +72,7 @@ class IterationNode(BaseNode):
|
|||
Condition(
|
||||
variable_selector=[self.node_id, "index"],
|
||||
comparison_operator="<",
|
||||
value=len(iterator_list_value)
|
||||
value=str(len(iterator_list_value))
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
@ -93,6 +91,7 @@ class IterationNode(BaseNode):
|
|||
)
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
|
|
@ -114,10 +113,11 @@ class IterationNode(BaseNode):
|
|||
rst = graph_engine.run()
|
||||
outputs: list[Any] = []
|
||||
for event in rst:
|
||||
yield event
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield event
|
||||
|
||||
# handle iteration run result
|
||||
if event.node_id in iteration_leaf_node_ids:
|
||||
if event.route_node_state.node_id in iteration_leaf_node_ids:
|
||||
# append to iteration output variable list
|
||||
outputs.append(variable_pool.get_any(self.node_data.output_selector))
|
||||
|
||||
|
|
@ -132,13 +132,23 @@ class IterationNode(BaseNode):
|
|||
next_index
|
||||
)
|
||||
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[next_index]
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[next_index]
|
||||
)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
raise WorkflowRunFailedError(event.reason)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.reason,
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
yield event
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
|
|
@ -148,14 +158,6 @@ class IterationNode(BaseNode):
|
|||
}
|
||||
)
|
||||
)
|
||||
except WorkflowRunFailedError as e:
|
||||
# iteration run failed
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# iteration run failed
|
||||
logger.exception("Iteration run failed")
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
|||
|
||||
@patch('extensions.ext_database.db.session.remove')
|
||||
@patch('extensions.ext_database.db.session.close')
|
||||
def test_run_parallel_in_workflow(mock_close, mock_remove, mocker):
|
||||
def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,209 @@
|
|||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, SystemVariable, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def test_run():
|
||||
graph_config = {
|
||||
"edges": [{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
}, {
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
}, {
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
}, {
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
}, {
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
}, {
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
}],
|
||||
"nodes": [{
|
||||
"data": {
|
||||
"title": "Start",
|
||||
"type": "start",
|
||||
"variables": []
|
||||
},
|
||||
"id": "start"
|
||||
}, {
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}, {
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer"
|
||||
},
|
||||
"id": "answer-2"
|
||||
}, {
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{
|
||||
"value_selector": ["sys", "query"],
|
||||
"variable": "arg1"
|
||||
}]
|
||||
},
|
||||
"id": "tt",
|
||||
}, {
|
||||
"data": {
|
||||
"answer": "{{#iteration-1.output#}}88888",
|
||||
"title": "answer 3",
|
||||
"type": "answer"
|
||||
},
|
||||
"id": "answer-3",
|
||||
}, {
|
||||
"data": {
|
||||
"conditions": [{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"]
|
||||
}],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else"
|
||||
},
|
||||
"id": "if-else",
|
||||
}, {
|
||||
"data": {
|
||||
"answer": "no hi",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 4",
|
||||
"type": "answer"
|
||||
},
|
||||
"id": "answer-4",
|
||||
}, {
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai"
|
||||
},
|
||||
"parameters": [{
|
||||
"description": "test",
|
||||
"name": "list_output",
|
||||
"required": False,
|
||||
"type": "array[string]"
|
||||
}],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor"
|
||||
},
|
||||
"id": "pe",
|
||||
}]
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id='1',
|
||||
graph_config=graph_config,
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.QUERY: 'dify',
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||
SystemVariable.USER_ID: '1'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])
|
||||
|
||||
iteration_node = IterationNode(
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=pool,
|
||||
start_at=time.perf_counter()
|
||||
),
|
||||
config={
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "tt",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
}
|
||||
)
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={
|
||||
'iterator_selector': 'dify'
|
||||
},
|
||||
outputs={
|
||||
'output': 'dify 123'
|
||||
}
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
with patch.object(TemplateTransformNode, '_run', new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
# print(type(item), item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
|
||||
assert count == 15
|
||||
Loading…
Reference in New Issue