add parallel branch output

This commit is contained in:
takatost 2024-07-25 19:39:06 +08:00
parent f4eb7cd037
commit 4097f7c069
4 changed files with 45 additions and 16 deletions

View File

@ -40,7 +40,10 @@ class GraphRunFailedEvent(BaseGraphEvent):
class BaseNodeEvent(GraphEngineEvent):
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")

View File

@ -115,6 +115,10 @@ class GraphEngine:
raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
@ -139,7 +143,8 @@ class GraphEngine:
yield from self._run_node(
route_node_state=route_node_state,
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
parallel_id=in_parallel_id
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
@ -155,7 +160,8 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=in_parallel_id
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
raise e
@ -287,14 +293,16 @@ class GraphEngine:
def _run_node(self,
route_node_state: RouteNodeState,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
"""
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
# get node config
@ -305,7 +313,8 @@ class GraphEngine:
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
@ -317,7 +326,8 @@ class GraphEngine:
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
@ -344,8 +354,9 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
route_node_state=route_node_state
parallel_start_node_id=parallel_start_node_id
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
@ -365,24 +376,27 @@ class GraphEngine:
)
yield NodeRunSucceededEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
route_node_state=route_node_state
parallel_start_node_id=parallel_start_node_id
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
retriever_resources=item.retriever_resources,
context=item.context
parallel_start_node_id=parallel_start_node_id,
)
except GenerateTaskStoppedException:
# trigger node run failed event
@ -390,7 +404,8 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
return
except Exception as e:

View File

@ -140,6 +140,7 @@ class AnswerStreamProcessor:
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -162,6 +163,7 @@ class AnswerStreamProcessor:
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self.route_position[answer_node_id] += 1

View File

@ -31,9 +31,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
start_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
if parallel_id:
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id
)
if 'llm' in next_node_id:
@ -43,14 +50,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)