mirror of https://github.com/langgenius/dify.git
add parallel branch output
This commit is contained in:
parent f4eb7cd037
commit 4097f7c069
@@ -40,7 +40,10 @@ class GraphRunFailedEvent(BaseGraphEvent):

 class BaseNodeEvent(GraphEngineEvent):
     route_node_state: RouteNodeState = Field(..., description="route node state")
-    parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
     # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")

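
Read as a model change: BaseNodeEvent drops the Field(...) wrapper on parallel_id and gains parallel_start_node_id, both defaulting to None for nodes that run outside a parallel block. A minimal stand-in of the resulting shape (a plain dataclass rather than the real pydantic model; the example node ids are made up):

from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseNodeEventSketch:
    # stand-in for BaseNodeEvent; only the fields touched by this commit
    route_node_state: object
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""


# A node outside any parallel block keeps both fields at None; a node inside one
# carries the parallel block id and the id of its branch's first node.
plain = BaseNodeEventSketch(route_node_state=object())
branch = BaseNodeEventSketch(route_node_state=object(), parallel_id="p1", parallel_start_node_id="llm-a")
assert plain.parallel_id is None and branch.parallel_start_node_id == "llm-a"
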
@@ -115,6 +115,10 @@ class GraphEngine:
             raise e

     def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
+        parallel_start_node_id = None
+        if in_parallel_id:
+            parallel_start_node_id = start_node_id
+
         next_node_id = start_node_id
         previous_route_node_state: Optional[RouteNodeState] = None
         while True:
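
The added lines in _run are the whole derivation: a branch run is entered with the parallel block's id, and the node it starts from doubles as the identifier of that branch's output. A tiny hypothetical helper restating just that rule, assuming nothing beyond the two arguments shown above:

from typing import Optional


def derive_parallel_start_node_id(start_node_id: str, in_parallel_id: Optional[str]) -> Optional[str]:
    # Mirrors the lines added to _run: outside a parallel block the value stays None;
    # inside one, the branch's entry node identifies which branch produced the output.
    return start_node_id if in_parallel_id else None


assert derive_parallel_start_node_id("llm-a", "parallel-1") == "llm-a"
assert derive_parallel_start_node_id("llm-a", None) is None
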
@@ -139,7 +143,8 @@
             yield from self._run_node(
                 route_node_state=route_node_state,
                 previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
-                parallel_id=in_parallel_id
+                parallel_id=in_parallel_id,
+                parallel_start_node_id=parallel_start_node_id
             )

             self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
@@ -155,7 +160,8 @@
                 route_node_state.failed_reason = str(e)
                 yield NodeRunFailedEvent(
                     route_node_state=route_node_state,
-                    parallel_id=in_parallel_id
+                    parallel_id=in_parallel_id,
+                    parallel_start_node_id=parallel_start_node_id
                 )
                 raise e

@@ -287,14 +293,16 @@
     def _run_node(self,
                   route_node_state: RouteNodeState,
                   previous_node_id: Optional[str] = None,
-                  parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
+                  parallel_id: Optional[str] = None,
+                  parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
         """
         Run node
         """
         # trigger node run start event
         yield NodeRunStartedEvent(
             route_node_state=route_node_state,
-            parallel_id=parallel_id
+            parallel_id=parallel_id,
+            parallel_start_node_id=parallel_start_node_id
         )

         # get node config
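
From here on the commit applies one repeated pattern: every NodeRun*Event yielded by _run_node now carries the same route_node_state / parallel_id / parallel_start_node_id trio. A hypothetical helper (not part of the commit) that makes the repetition explicit:

from typing import Any, Optional


def _base_event_kwargs(route_node_state: Any,
                       parallel_id: Optional[str] = None,
                       parallel_start_node_id: Optional[str] = None) -> dict:
    # Hypothetical, for illustration only: bundles the three keyword arguments that
    # every NodeRun*Event emitted by _run_node now repeats in the hunks below.
    return {
        "route_node_state": route_node_state,
        "parallel_id": parallel_id,
        "parallel_start_node_id": parallel_start_node_id,
    }

# e.g. yield NodeRunStartedEvent(**_base_event_kwargs(route_node_state, parallel_id, parallel_start_node_id))
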
@@ -305,7 +313,8 @@
             route_node_state.failed_reason = f'Node {node_id} config not found.'
             yield NodeRunFailedEvent(
                 route_node_state=route_node_state,
-                parallel_id=parallel_id
+                parallel_id=parallel_id,
+                parallel_start_node_id=parallel_start_node_id
             )
             return

@@ -317,7 +326,8 @@
             route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
             yield NodeRunFailedEvent(
                 route_node_state=route_node_state,
-                parallel_id=parallel_id
+                parallel_id=parallel_id,
+                parallel_start_node_id=parallel_start_node_id
             )
             return

@@ -344,8 +354,9 @@

                     if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                         yield NodeRunFailedEvent(
+                            route_node_state=route_node_state,
                             parallel_id=parallel_id,
-                            route_node_state=route_node_state
+                            parallel_start_node_id=parallel_start_node_id
                         )
                     elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                         if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
@@ -365,24 +376,27 @@
                             )

                         yield NodeRunSucceededEvent(
+                            route_node_state=route_node_state,
                             parallel_id=parallel_id,
-                            route_node_state=route_node_state
+                            parallel_start_node_id=parallel_start_node_id
                         )

                         break
                     elif isinstance(item, RunStreamChunkEvent):
                         yield NodeRunStreamChunkEvent(
-                            route_node_state=route_node_state,
-                            parallel_id=parallel_id,
                             chunk_content=item.chunk_content,
                             from_variable_selector=item.from_variable_selector,
+                            route_node_state=route_node_state,
+                            parallel_id=parallel_id,
+                            parallel_start_node_id=parallel_start_node_id,
                         )
                     elif isinstance(item, RunRetrieverResourceEvent):
                         yield NodeRunRetrieverResourceEvent(
+                            retriever_resources=item.retriever_resources,
+                            context=item.context,
                             route_node_state=route_node_state,
                             parallel_id=parallel_id,
-                            retriever_resources=item.retriever_resources,
-                            context=item.context
+                            parallel_start_node_id=parallel_start_node_id,
                         )
             except GenerateTaskStoppedException:
                 # trigger node run failed event
@@ -390,7 +404,8 @@
                 route_node_state.failed_reason = "Workflow stopped."
                 yield NodeRunFailedEvent(
                     route_node_state=route_node_state,
-                    parallel_id=parallel_id
+                    parallel_id=parallel_id,
+                    parallel_start_node_id=parallel_start_node_id,
                 )
                 return
             except Exception as e:
@@ -140,6 +140,7 @@ class AnswerStreamProcessor:
                         chunk_content=route_chunk.text,
                         route_node_state=event.route_node_state,
                         parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
                     )
                 else:
                     route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@@ -162,6 +163,7 @@ class AnswerStreamProcessor:
                         from_variable_selector=value_selector,
                         route_node_state=event.route_node_state,
                         parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
                     )

                 self.route_position[answer_node_id] += 1
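
AnswerStreamProcessor only forwards the new field on the chunk events it re-emits; it does not interpret it. A sketch of how a consumer of those events might use the (parallel_id, parallel_start_node_id) pair to keep branch outputs apart (illustrative only, with stand-in types and made-up ids):

from collections import defaultdict
from typing import Optional


class BranchChunkBuffer:
    """Illustrative only: accumulate streamed chunk text per parallel branch."""

    def __init__(self) -> None:
        # key is (parallel_id, parallel_start_node_id); (None, None) is the main flow
        self._buffers: dict[tuple[Optional[str], Optional[str]], list[str]] = defaultdict(list)

    def add(self, chunk_content: str,
            parallel_id: Optional[str] = None,
            parallel_start_node_id: Optional[str] = None) -> None:
        self._buffers[(parallel_id, parallel_start_node_id)].append(chunk_content)

    def text_for(self, parallel_id: Optional[str], parallel_start_node_id: Optional[str]) -> str:
        return "".join(self._buffers[(parallel_id, parallel_start_node_id)])


buf = BranchChunkBuffer()
buf.add("Hello ", parallel_id="p1", parallel_start_node_id="llm-a")
buf.add("from A", parallel_id="p1", parallel_start_node_id="llm-a")
buf.add("Hi from B", parallel_id="p1", parallel_start_node_id="llm-b")
assert buf.text_for("p1", "llm-a") == "Hello from A"
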
@@ -31,9 +31,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
         start_at=datetime.now(timezone.utc).replace(tzinfo=None)
     )

+    parallel_id = graph.node_parallel_mapping.get(next_node_id)
+    parallel_start_node_id = None
+    if parallel_id:
+        parallel = graph.parallel_mapping.get(parallel_id)
+        parallel_start_node_id = parallel.start_from_node_id if parallel else None
+
     yield NodeRunStartedEvent(
         route_node_state=route_node_state,
         parallel_id=graph.node_parallel_mapping.get(next_node_id),
+        parallel_start_node_id=parallel_start_node_id
     )

     if 'llm' in next_node_id:
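
The lookup chain used by the test helper is: node id to parallel id via graph.node_parallel_mapping, then parallel id to the parallel's entry node via parallel_mapping[...].start_from_node_id. The same chain in isolation, with stand-in mappings instead of the real Graph object (attribute names taken from the hunk above, data made up):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelSketch:
    # stand-in for the object stored in graph.parallel_mapping
    start_from_node_id: str


node_parallel_mapping = {"llm-a": "p1"}                                # node id -> parallel id
parallel_mapping = {"p1": ParallelSketch(start_from_node_id="llm-a")}  # parallel id -> parallel


def resolve(next_node_id: str) -> tuple[Optional[str], Optional[str]]:
    parallel_id = node_parallel_mapping.get(next_node_id)
    parallel_start_node_id = None
    if parallel_id:
        parallel = parallel_mapping.get(parallel_id)
        parallel_start_node_id = parallel.start_from_node_id if parallel else None
    return parallel_id, parallel_start_node_id


assert resolve("llm-a") == ("p1", "llm-a")
assert resolve("start") == (None, None)
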
@@ -43,14 +50,16 @@
                 chunk_content=str(i),
                 route_node_state=route_node_state,
                 from_variable_selector=[next_node_id, "text"],
-                parallel_id=graph.node_parallel_mapping.get(next_node_id),
+                parallel_id=parallel_id,
+                parallel_start_node_id=parallel_start_node_id
             )

     route_node_state.status = RouteNodeState.Status.SUCCESS
     route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
     yield NodeRunSucceededEvent(
         route_node_state=route_node_state,
-        parallel_id=graph.node_parallel_mapping.get(next_node_id),
+        parallel_id=parallel_id,
+        parallel_start_node_id=parallel_start_node_id
     )
