mirror of https://github.com/langgenius/dify.git
add parallel branch events
This commit is contained in:
parent
483f71f03c
commit
63addf8c94
|
|
@ -98,7 +98,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow_engine_manager.run(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class WorkflowAppRunner:
|
|||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow_engine_manager.run(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
|
|
|
|||
|
|
@ -67,23 +67,24 @@ class NodeRunFailedEvent(BaseNodeEvent):
|
|||
|
||||
|
||||
###########################################
|
||||
# Parallel Events
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelEvent(GraphEngineEvent):
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
|
||||
parallel_id: str = Field(..., description="parallel id")
|
||||
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
||||
|
||||
|
||||
class ParallelRunStartedEvent(BaseParallelEvent):
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelRunSucceededEvent(BaseParallelEvent):
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelRunFailedEvent(BaseParallelEvent):
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
||||
reason: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
|
|
@ -113,4 +114,4 @@ class IterationRunFailedEvent(BaseIterationEvent):
|
|||
reason: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelEvent | BaseIterationEvent
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ from core.workflow.graph_engine.entities.event import (
|
|||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
|
|
@ -237,14 +240,12 @@ class GraphEngine:
|
|||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
run_thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'q': q
|
||||
})
|
||||
|
||||
run_thread.start()
|
||||
}).start()
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
|
|
@ -253,16 +254,15 @@ class GraphEngine:
|
|||
if event is None:
|
||||
break
|
||||
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
yield event
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(edge_mappings):
|
||||
break
|
||||
|
||||
continue
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.reason)
|
||||
else:
|
||||
yield event
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
|
|
@ -286,6 +286,11 @@ class GraphEngine:
|
|||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
q.put(ParallelBranchRunStartedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
))
|
||||
|
||||
# run node
|
||||
generator = self._run(
|
||||
start_node_id=parallel_start_node_id,
|
||||
|
|
@ -296,12 +301,23 @@ class GraphEngine:
|
|||
q.put(item)
|
||||
|
||||
# trigger graph run success event
|
||||
q.put(GraphRunSucceededEvent())
|
||||
q.put(ParallelBranchRunSucceededEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
))
|
||||
except GraphRunFailedError as e:
|
||||
q.put(GraphRunFailedEvent(reason=e.error))
|
||||
q.put(ParallelBranchRunFailedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
reason=e.error
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating in parallel")
|
||||
q.put(GraphRunFailedEvent(reason=str(e)))
|
||||
q.put(ParallelBranchRunFailedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
reason=str(e)
|
||||
))
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -11,7 +11,7 @@ from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
|||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import GraphRunFailedEvent
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
|
|
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class WorkflowEntry:
|
||||
def run_workflow(
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
|
|
@ -37,7 +37,7 @@ class WorkflowEntry:
|
|||
system_inputs: Mapping[SystemVariable, Any],
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
call_depth: int = 0
|
||||
) -> None:
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
|
|
@ -110,6 +110,7 @@ class WorkflowEntry:
|
|||
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
yield event
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
|
@ -125,10 +126,10 @@ class WorkflowEntry:
|
|||
)
|
||||
return
|
||||
|
||||
def single_step_run_workflow_node(self, workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
|
||||
def single_step_run(self, workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
:param workflow: Workflow instance
|
||||
|
|
|
|||
|
|
@ -213,7 +213,7 @@ class WorkflowService:
|
|||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
|
||||
node_instance, node_run_result = workflow_engine_manager.single_step_run(
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
|
|
|
|||
|
|
@ -248,13 +248,13 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
|||
)
|
||||
)
|
||||
|
||||
print("")
|
||||
# print("")
|
||||
|
||||
with patch.object(LLMNode, '_run', new=llm_generator):
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
print(type(item), item)
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
if isinstance(item, NodeRunSucceededEvent):
|
||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||
|
|
@ -267,7 +267,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
|||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 17
|
||||
assert len(items) == 21
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == 'start'
|
||||
|
|
@ -402,7 +402,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
|||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 19
|
||||
assert len(items) == 23
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == 'start'
|
||||
|
|
|
|||
Loading…
Reference in New Issue