add parallel branch events

This commit is contained in:
takatost 2024-07-26 20:27:17 +08:00
parent 483f71f03c
commit 63addf8c94
7 changed files with 50 additions and 32 deletions

View File

@ -98,7 +98,7 @@ class AdvancedChatAppRunner(AppRunner):
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow_engine_manager.run(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT

View File

@ -67,7 +67,7 @@ class WorkflowAppRunner:
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow_engine_manager.run(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT

View File

@ -67,23 +67,24 @@ class NodeRunFailedEvent(BaseNodeEvent):
###########################################
# Parallel Events
# Parallel Branch Events
###########################################
class BaseParallelEvent(GraphEngineEvent):
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
parallel_start_node_id: str = Field(..., description="parallel start node id")
class ParallelRunStartedEvent(BaseParallelEvent):
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
pass
class ParallelRunSucceededEvent(BaseParallelEvent):
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
pass
class ParallelRunFailedEvent(BaseParallelEvent):
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
reason: str = Field(..., description="failed reason")
@ -113,4 +114,4 @@ class IterationRunFailedEvent(BaseIterationEvent):
reason: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent

View File

@ -23,6 +23,9 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
@ -237,14 +240,12 @@ class GraphEngine:
# new thread
for edge in edge_mappings:
run_thread = threading.Thread(target=self._run_parallel_node, kwargs={
threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(),
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'q': q
})
run_thread.start()
}).start()
succeeded_count = 0
while True:
@ -253,16 +254,15 @@ class GraphEngine:
if event is None:
break
if isinstance(event, GraphRunSucceededEvent):
yield event
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
continue
elif isinstance(event, GraphRunFailedEvent):
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.reason)
else:
yield event
except queue.Empty:
continue
@ -286,6 +286,11 @@ class GraphEngine:
"""
with flask_app.app_context():
try:
q.put(ParallelBranchRunStartedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
))
# run node
generator = self._run(
start_node_id=parallel_start_node_id,
@ -296,12 +301,23 @@ class GraphEngine:
q.put(item)
# trigger graph run success event
q.put(GraphRunSucceededEvent())
q.put(ParallelBranchRunSucceededEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
))
except GraphRunFailedError as e:
q.put(GraphRunFailedEvent(reason=e.error))
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
reason=e.error
))
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(GraphRunFailedEvent(reason=str(e)))
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
reason=str(e)
))
finally:
db.session.remove()

View File

@ -1,5 +1,5 @@
import logging
from collections.abc import Mapping, Sequence
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
@ -11,7 +11,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphRunFailedEvent
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseNode
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class WorkflowEntry:
def run_workflow(
def run(
self,
*,
workflow: Workflow,
@ -37,7 +37,7 @@ class WorkflowEntry:
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0
) -> None:
) -> Generator[GraphEngineEvent, None, None]:
"""
:param workflow: Workflow instance
:param user_id: user id
@ -110,6 +110,7 @@ class WorkflowEntry:
graph_runtime_state=graph_engine.graph_runtime_state,
event=event
)
yield event
except GenerateTaskStoppedException:
pass
except Exception as e:
@ -125,10 +126,10 @@ class WorkflowEntry:
)
return
def single_step_run_workflow_node(self, workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
def single_step_run(self, workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
"""
Single step run workflow node
:param workflow: Workflow instance

View File

@ -213,7 +213,7 @@ class WorkflowService:
start_at = time.perf_counter()
try:
node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
node_instance, node_run_result = workflow_engine_manager.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,

View File

@ -248,13 +248,13 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
)
)
print("")
# print("")
with patch.object(LLMNode, '_run', new=llm_generator):
items = []
generator = graph_engine.run()
for item in generator:
print(type(item), item)
# print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
@ -267,7 +267,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
]:
assert item.parallel_id is not None
assert len(items) == 17
assert len(items) == 21
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'
@ -402,7 +402,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
]:
assert item.parallel_id is not None
assert len(items) == 19
assert len(items) == 23
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'