refactor: streamline input handling and update type hints for event data structures

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-23 12:06:56 +08:00
parent 8b8801d43c
commit d4ddcda3f2
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
3 changed files with 56 additions and 49 deletions

View File

@ -189,6 +189,17 @@ class WorkflowBasedAppRunner(AppRunner):
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunRetryEvent): elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event( self._publish_event(
QueueNodeRetryEvent( QueueNodeRetryEvent(
node_execution_id=event.id, node_execution_id=event.id,
@ -204,19 +215,11 @@ class WorkflowBasedAppRunner(AppRunner):
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id, parallel_mode_run_id=event.parallel_mode_run_id,
inputs=event.route_node_state.node_run_result.inputs inputs=inputs,
if event.route_node_state.node_run_result process_data=process_data,
else {}, outputs=outputs,
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error, error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata execution_metadata=execution_metadata,
if event.route_node_state.node_run_result
else {},
retry_index=event.retry_index, retry_index=event.retry_index,
) )
) )
@ -239,6 +242,17 @@ class WorkflowBasedAppRunner(AppRunner):
) )
) )
elif isinstance(event, NodeRunSucceededEvent): elif isinstance(event, NodeRunSucceededEvent):
node_run_result = event.route_node_state.node_run_result
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
outputs = node_run_result.outputs
execution_metadata = node_run_result.metadata
else:
inputs = {}
process_data = {}
outputs = {}
execution_metadata = {}
self._publish_event( self._publish_event(
QueueNodeSucceededEvent( QueueNodeSucceededEvent(
node_execution_id=event.id, node_execution_id=event.id,
@ -250,18 +264,10 @@ class WorkflowBasedAppRunner(AppRunner):
parent_parallel_id=event.parent_parallel_id, parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at, start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs inputs=inputs,
if event.route_node_state.node_run_result process_data=process_data,
else {}, outputs=outputs,
process_data=event.route_node_state.node_run_result.process_data execution_metadata=execution_metadata,
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id, in_iteration_id=event.in_iteration_id,
) )
) )

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
@ -85,9 +86,9 @@ class QueueIterationStartEvent(AppQueueEvent):
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[Mapping[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent): class QueueIterationNextEvent(AppQueueEvent):
@ -139,9 +140,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
start_at: datetime start_at: datetime
node_run_index: int node_run_index: int
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[dict[str, Any]] = None metadata: Optional[Mapping[str, Any]] = None
steps: int = 0 steps: int = 0
error: Optional[str] = None error: Optional[str] = None
@ -304,9 +305,9 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
start_at: datetime start_at: datetime
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None error: Optional[str] = None
@ -319,10 +320,10 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
event: QueueEvent = QueueEvent.RETRY event: QueueEvent = QueueEvent.RETRY
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str error: str
retry_index: int # retry index retry_index: int # retry index
@ -351,10 +352,10 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
start_at: datetime start_at: datetime
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str error: str
@ -382,10 +383,10 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
start_at: datetime start_at: datetime
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str error: str
@ -413,10 +414,10 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
start_at: datetime start_at: datetime
inputs: Optional[dict[str, Any]] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: str error: str

View File

@ -33,7 +33,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0) exceptions_count: int = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent): class GraphRunPartialSucceededEvent(BaseGraphEvent):