mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 11:06:46 +08:00
chore: merge feat/queue-based-graph-engine (#25833)
This commit is contained in:
commit
e6d65fe356
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -43,6 +43,10 @@ jobs:
|
|||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: uv sync --project api --dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
|
- name: Run Import Linter
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
run: uv run --directory api --dev lint-imports
|
||||||
|
|
||||||
- name: Run Basedpyright Checks
|
- name: Run Basedpyright Checks
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: dev/basedpyright-check
|
run: dev/basedpyright-check
|
||||||
|
|||||||
@ -22,14 +22,15 @@ containers =
|
|||||||
ignore_imports =
|
ignore_imports =
|
||||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||||
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||||
|
|
||||||
|
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
|
||||||
|
|
||||||
[importlinter:contract:rsc]
|
[importlinter:contract:rsc]
|
||||||
name = RSC
|
name = RSC
|
||||||
@ -57,9 +58,9 @@ layers =
|
|||||||
orchestration
|
orchestration
|
||||||
command_processing
|
command_processing
|
||||||
event_management
|
event_management
|
||||||
error_handling
|
error_handler
|
||||||
graph_traversal
|
graph_traversal
|
||||||
state_management
|
graph_state_manager
|
||||||
worker_management
|
worker_management
|
||||||
domain
|
domain
|
||||||
containers =
|
containers =
|
||||||
@ -86,14 +87,6 @@ forbidden_modules =
|
|||||||
core.workflow.graph_engine.command_processing
|
core.workflow.graph_engine.command_processing
|
||||||
core.workflow.graph_engine.event_management
|
core.workflow.graph_engine.event_management
|
||||||
|
|
||||||
[importlinter:contract:error-handling-strategies]
|
|
||||||
name = Error Handling Strategies
|
|
||||||
type = independence
|
|
||||||
modules =
|
|
||||||
core.workflow.graph_engine.error_handling.abort_strategy
|
|
||||||
core.workflow.graph_engine.error_handling.retry_strategy
|
|
||||||
core.workflow.graph_engine.error_handling.fail_branch_strategy
|
|
||||||
core.workflow.graph_engine.error_handling.default_value_strategy
|
|
||||||
|
|
||||||
[importlinter:contract:graph-traversal-components]
|
[importlinter:contract:graph-traversal-components]
|
||||||
name = Graph Traversal Components
|
name = Graph Traversal Components
|
||||||
|
|||||||
@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
# Returns whether the obtained value is obtained, and None if it does not
|
# Returns whether the obtained value is obtained, and None if it does not
|
||||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
|
||||||
if namespace_cache:
|
if namespace_cache:
|
||||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||||
if kv_data is None:
|
if kv_data is None:
|
||||||
|
|||||||
@ -355,7 +355,7 @@ class WorkflowResponseConverter:
|
|||||||
else WorkflowNodeExecutionStatus.FAILED,
|
else WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=None,
|
error=None,
|
||||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||||
execution_metadata=event.metadata,
|
execution_metadata=event.metadata,
|
||||||
finished_at=int(time.time()),
|
finished_at=int(time.time()),
|
||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
@ -442,7 +442,7 @@ class WorkflowResponseConverter:
|
|||||||
else WorkflowNodeExecutionStatus.FAILED,
|
else WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=None,
|
error=None,
|
||||||
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
|
||||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
|
||||||
execution_metadata=event.metadata,
|
execution_metadata=event.metadata,
|
||||||
finished_at=int(time.time()),
|
finished_at=int(time.time()),
|
||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
|
|||||||
@ -384,7 +384,6 @@ class WorkflowBasedAppRunner:
|
|||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
@ -406,7 +405,6 @@ class WorkflowBasedAppRunner:
|
|||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
agent_strategy=event.agent_strategy,
|
agent_strategy=event.agent_strategy,
|
||||||
provider_type=event.provider_type,
|
provider_type=event.provider_type,
|
||||||
provider_id=event.provider_id,
|
provider_id=event.provider_id,
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
@ -79,9 +79,9 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueIterationNextEvent(AppQueueEvent):
|
class QueueIterationNextEvent(AppQueueEvent):
|
||||||
@ -97,7 +97,7 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_title: str
|
node_title: str
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
output: Optional[Any] = None # output for the current iteration
|
output: Any = None # output for the current iteration
|
||||||
|
|
||||||
|
|
||||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||||
@ -114,9 +114,9 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
|||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
@ -143,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
|
|||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueLoopNextEvent(AppQueueEvent):
|
class QueueLoopNextEvent(AppQueueEvent):
|
||||||
@ -171,7 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
|||||||
parallel_mode_run_id: str | None = None
|
parallel_mode_run_id: str | None = None
|
||||||
"""iteration run in parallel mode run id"""
|
"""iteration run in parallel mode run id"""
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
output: Optional[Any] = None # output for the current loop
|
output: Any = None # output for the current loop
|
||||||
|
|
||||||
|
|
||||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||||
@ -185,7 +185,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
|||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_title: str
|
node_title: str
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
"""parallel start node id if node is in parallel"""
|
"""parallel start node id if node is in parallel"""
|
||||||
@ -196,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
|||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
@ -299,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||||
outputs: dict[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||||
@ -319,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
|||||||
|
|
||||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||||
exceptions_count: int
|
exceptions_count: int
|
||||||
outputs: dict[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeStartedEvent(AppQueueEvent):
|
class QueueNodeStartedEvent(AppQueueEvent):
|
||||||
@ -334,16 +334,16 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||||||
node_title: str
|
node_title: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: Optional[str] = None
|
parallel_start_node_id: str | None = None
|
||||||
parent_parallel_id: Optional[str] = None
|
parent_parallel_id: str | None = None
|
||||||
parent_parallel_start_node_id: Optional[str] = None
|
parent_parallel_start_node_id: str | None = None
|
||||||
in_iteration_id: Optional[str] = None
|
in_iteration_id: str | None = None
|
||||||
in_loop_id: Optional[str] = None
|
in_loop_id: str | None = None
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
parallel_mode_run_id: Optional[str] = None
|
parallel_mode_run_id: str | None = None
|
||||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
agent_strategy: AgentNodeStrategyInit | None = None
|
||||||
|
|
||||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||||
@ -360,7 +360,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
"""parallel start node id if node is in parallel"""
|
"""parallel start node id if node is in parallel"""
|
||||||
@ -374,12 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: Optional[str] = None
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class QueueAgentLogEvent(AppQueueEvent):
|
class QueueAgentLogEvent(AppQueueEvent):
|
||||||
@ -395,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
|
|
||||||
@ -404,9 +404,9 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
|||||||
|
|
||||||
event: QueueEvent = QueueEvent.RETRY
|
event: QueueEvent = QueueEvent.RETRY
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
@ -423,7 +423,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
"""parallel start node id if node is in parallel"""
|
"""parallel start node id if node is in parallel"""
|
||||||
@ -437,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
@ -455,16 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
in_iteration_id: Optional[str] = None
|
in_iteration_id: str | None = None
|
||||||
"""iteration id if node is in iteration"""
|
"""iteration id if node is in iteration"""
|
||||||
in_loop_id: str | None = None
|
in_loop_id: str | None = None
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
@ -494,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.ERROR
|
event: QueueEvent = QueueEvent.ERROR
|
||||||
error: Any | None = None
|
error: Any = None
|
||||||
|
|
||||||
|
|
||||||
class QueuePingEvent(AppQueueEvent):
|
class QueuePingEvent(AppQueueEvent):
|
||||||
|
|||||||
@ -136,7 +136,7 @@ class MessageEndStreamResponse(StreamResponse):
|
|||||||
|
|
||||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||||
id: str
|
id: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
files: Sequence[Mapping[str, Any]] | None = None
|
files: Sequence[Mapping[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
@ -173,7 +173,7 @@ class AgentThoughtStreamResponse(StreamResponse):
|
|||||||
thought: str | None = None
|
thought: str | None = None
|
||||||
observation: str | None = None
|
observation: str | None = None
|
||||||
tool: str | None = None
|
tool: str | None = None
|
||||||
tool_labels: dict | None = None
|
tool_labels: Mapping[str, object] = Field(default_factory=dict)
|
||||||
tool_input: str | None = None
|
tool_input: str | None = None
|
||||||
message_files: list[str] | None = None
|
message_files: list[str] | None = None
|
||||||
|
|
||||||
@ -226,7 +226,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
total_steps: int
|
total_steps: int
|
||||||
created_by: dict | None = None
|
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int
|
finished_at: int
|
||||||
exceptions_count: int | None = 0
|
exceptions_count: int | None = 0
|
||||||
@ -256,7 +256,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
inputs_truncated: bool = False
|
inputs_truncated: bool = False
|
||||||
created_at: int
|
created_at: int
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: dict[str, object] = Field(default_factory=dict)
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
parent_parallel_id: str | None = None
|
parent_parallel_id: str | None = None
|
||||||
@ -513,7 +513,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
execution_metadata: Mapping | None = None
|
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
finished_at: int
|
finished_at: int
|
||||||
steps: int
|
steps: int
|
||||||
|
|
||||||
@ -565,11 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
|||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
created_at: int
|
created_at: int
|
||||||
pre_loop_output: Any | None = None
|
pre_loop_output: Any = None
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: Optional[str] = None
|
parallel_start_node_id: str | None = None
|
||||||
parallel_mode_run_id: Optional[str] = None
|
parallel_mode_run_id: str | None = None
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
@ -600,7 +600,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
execution_metadata: Mapping | None = None
|
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
finished_at: int
|
finished_at: int
|
||||||
steps: int
|
steps: int
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
@ -710,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
|||||||
conversation_id: str
|
conversation_id: str
|
||||||
message_id: str
|
message_id: str
|
||||||
answer: str
|
answer: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
data: Data
|
data: Data
|
||||||
@ -730,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
|||||||
mode: str
|
mode: str
|
||||||
message_id: str
|
message_id: str
|
||||||
answer: str
|
answer: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
data: Data
|
data: Data
|
||||||
@ -778,7 +778,7 @@ class AgentLogStreamResponse(StreamResponse):
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||||
|
|||||||
@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
|
|||||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||||
if message.event.outputs is None:
|
if message.event.outputs is None:
|
||||||
continue
|
continue
|
||||||
self.msg_text += message.event.outputs.get("output", "")
|
output = message.event.outputs.get("output", "")
|
||||||
|
if isinstance(output, str):
|
||||||
|
self.msg_text += output
|
||||||
self.last_message = message
|
self.last_message = message
|
||||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||||
if len(sentence_arr) >= min(self.max_sentence, 7):
|
if len(sentence_arr) >= min(self.max_sentence, 7):
|
||||||
@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
|
|||||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||||
)
|
)
|
||||||
future_queue.put(futures_result)
|
future_queue.put(futures_result)
|
||||||
if text_tmp:
|
if isinstance(text_tmp, str):
|
||||||
self.msg_text = text_tmp
|
self.msg_text = text_tmp
|
||||||
else:
|
else:
|
||||||
self.msg_text = ""
|
self.msg_text = ""
|
||||||
|
|||||||
@ -1,9 +1,33 @@
|
|||||||
from abc import abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class CodeNodeProvider(BaseModel):
|
class VariableConfig(TypedDict):
|
||||||
|
variable: str
|
||||||
|
value_selector: Sequence[str | int]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputConfig(TypedDict):
|
||||||
|
type: str
|
||||||
|
children: None
|
||||||
|
|
||||||
|
|
||||||
|
class CodeConfig(TypedDict):
|
||||||
|
variables: Sequence[VariableConfig]
|
||||||
|
code_language: str
|
||||||
|
code: str
|
||||||
|
outputs: Mapping[str, OutputConfig]
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultConfig(TypedDict):
|
||||||
|
type: str
|
||||||
|
config: CodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CodeNodeProvider(BaseModel, ABC):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_language() -> str:
|
def get_language() -> str:
|
||||||
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls):
|
def get_default_config(cls) -> DefaultConfig:
|
||||||
return {
|
return {
|
||||||
"type": "code",
|
"type": "code",
|
||||||
"config": {
|
"config": {
|
||||||
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
|
"variables": [
|
||||||
|
{"variable": "arg1", "value_selector": []},
|
||||||
|
{"variable": "arg2", "value_selector": []},
|
||||||
|
],
|
||||||
"code_language": cls.get_language(),
|
"code_language": cls.get_language(),
|
||||||
"code": cls.get_default_code(),
|
"code": cls.get_default_code(),
|
||||||
"outputs": {"result": {"type": "string", "children": None}},
|
"outputs": {"result": {"type": "string", "children": None}},
|
||||||
|
|||||||
@ -160,7 +160,7 @@ class ErrorData(BaseModel):
|
|||||||
sentence.
|
sentence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: Any | None = None
|
data: Any = None
|
||||||
"""
|
"""
|
||||||
Additional information about the error. The value of this member is defined by the
|
Additional information about the error. The value of this member is defined by the
|
||||||
sender (e.g. detailed error information, nested errors etc.).
|
sender (e.g. detailed error information, nested errors etc.).
|
||||||
|
|||||||
@ -22,13 +22,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelProviderFactory:
|
class ModelProviderFactory:
|
||||||
provider_position_map: dict[str, int]
|
|
||||||
|
|
||||||
def __init__(self, tenant_id: str):
|
def __init__(self, tenant_id: str):
|
||||||
from core.plugin.impl.model import PluginModelClient
|
from core.plugin.impl.model import PluginModelClient
|
||||||
|
|
||||||
self.provider_position_map = {}
|
|
||||||
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.plugin_model_manager = PluginModelClient()
|
self.plugin_model_manager = PluginModelClient()
|
||||||
|
|
||||||
|
|||||||
@ -408,11 +408,11 @@ class TraceTask:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
trace_type: Any,
|
trace_type: Any,
|
||||||
message_id: Optional[str] = None,
|
message_id: str | None = None,
|
||||||
workflow_execution: Optional["WorkflowExecution"] = None,
|
workflow_execution: Optional["WorkflowExecution"] = None,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: str | None = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: str | None = None,
|
||||||
timer: Optional[Any] = None,
|
timer: Any | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.trace_type = trace_type
|
self.trace_type = trace_type
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
@ -16,10 +17,10 @@ class ToolApiEntity(BaseModel):
|
|||||||
description: I18nObject
|
description: I18nObject
|
||||||
parameters: list[ToolParameter] | None = None
|
parameters: list[ToolParameter] | None = None
|
||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
output_schema: dict | None = None
|
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
|
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderApiEntity(BaseModel):
|
class ToolProviderApiEntity(BaseModel):
|
||||||
@ -27,17 +28,17 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
author: str
|
author: str
|
||||||
name: str # identifier
|
name: str # identifier
|
||||||
description: I18nObject
|
description: I18nObject
|
||||||
icon: str | dict
|
icon: str | Mapping[str, str]
|
||||||
icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool")
|
icon_dark: str | Mapping[str, str] = ""
|
||||||
label: I18nObject # label
|
label: I18nObject # label
|
||||||
type: ToolProviderType
|
type: ToolProviderType
|
||||||
masked_credentials: dict | None = None
|
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||||
original_credentials: dict | None = None
|
original_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||||
is_team_authorization: bool = False
|
is_team_authorization: bool = False
|
||||||
allow_delete: bool = True
|
allow_delete: bool = True
|
||||||
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
|
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
|
||||||
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
|
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
|
||||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
|
||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
# MCP
|
# MCP
|
||||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||||
@ -105,7 +106,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
|
|||||||
is_default: bool = Field(
|
is_default: bool = Field(
|
||||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||||
)
|
)
|
||||||
credentials: dict = Field(description="The credentials of the provider")
|
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||||
|
|||||||
@ -187,7 +187,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
error: str | None = Field(default=None, description="The error message")
|
error: str | None = Field(default=None, description="The error message")
|
||||||
status: LogStatus = Field(..., description="The status of the log")
|
status: LogStatus = Field(..., description="The status of the log")
|
||||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||||
metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log")
|
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||||
|
|
||||||
class RetrieverResourceMessage(BaseModel):
|
class RetrieverResourceMessage(BaseModel):
|
||||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||||
@ -363,9 +363,9 @@ class ToolDescription(BaseModel):
|
|||||||
|
|
||||||
class ToolEntity(BaseModel):
|
class ToolEntity(BaseModel):
|
||||||
identity: ToolIdentity
|
identity: ToolIdentity
|
||||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
|
||||||
description: ToolDescription | None = None
|
description: ToolDescription | None = None
|
||||||
output_schema: dict | None = None
|
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
@ -378,21 +378,23 @@ class ToolEntity(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class OAuthSchema(BaseModel):
|
class OAuthSchema(BaseModel):
|
||||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
client_schema: list[ProviderConfig] = Field(
|
||||||
|
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||||
|
)
|
||||||
credentials_schema: list[ProviderConfig] = Field(
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
default_factory=list, description="The schema of the OAuth credentials"
|
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntity(BaseModel):
|
class ToolProviderEntity(BaseModel):
|
||||||
identity: ToolProviderIdentity
|
identity: ToolProviderIdentity
|
||||||
plugin_id: str | None = None
|
plugin_id: str | None = None
|
||||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
|
||||||
oauth_schema: OAuthSchema | None = None
|
oauth_schema: OAuthSchema | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||||
tools: list[ToolEntity] = Field(default_factory=list)
|
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolParameterConfiguration(BaseModel):
|
class WorkflowToolParameterConfiguration(BaseModel):
|
||||||
|
|||||||
@ -72,7 +72,6 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
),
|
),
|
||||||
llm=remote_mcp_tool.description or "",
|
llm=remote_mcp_tool.description or "",
|
||||||
),
|
),
|
||||||
output_schema=None,
|
|
||||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||||
)
|
)
|
||||||
for remote_mcp_tool in remote_mcp_tools
|
for remote_mcp_tool in remote_mcp_tools
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from collections.abc import Generator, Iterable
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Union, cast
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
@ -152,9 +152,9 @@ class ToolEngine:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||||
workflow_call_depth: int,
|
workflow_call_depth: int,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: str | None = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: str | None = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: str | None = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Workflow invokes the tool with the given arguments.
|
Workflow invokes the tool with the given arguments.
|
||||||
|
|||||||
@ -14,31 +14,17 @@ from sqlalchemy.orm import Session
|
|||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
|
||||||
from core.plugin.impl.tool import PluginToolManager
|
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
|
||||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
|
||||||
from core.tools.mcp_tool.tool import MCPTool
|
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
||||||
from core.tools.plugin_tool.tool import PluginTool
|
|
||||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
from models.provider_ids import ToolProviderID
|
|
||||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
||||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.workflow.nodes.tool.entities import ToolEntity
|
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.agent.entities import AgentToolEntity
|
from core.agent.entities import AgentToolEntity
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
from core.tools.builtin_tool.tool import BuiltinTool
|
||||||
@ -54,12 +40,21 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
|
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||||
|
from core.tools.mcp_tool.tool import MCPTool
|
||||||
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
|
from core.tools.plugin_tool.tool import PluginTool
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||||
|
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||||
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.provider_ids import ToolProviderID
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||||
|
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -890,7 +885,7 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||||
try:
|
try:
|
||||||
workflow_provider: WorkflowToolProvider | None = (
|
workflow_provider: WorkflowToolProvider | None = (
|
||||||
db.session.query(WorkflowToolProvider)
|
db.session.query(WorkflowToolProvider)
|
||||||
@ -901,13 +896,13 @@ class ToolManager:
|
|||||||
if workflow_provider is None:
|
if workflow_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||||
|
|
||||||
icon: dict = json.loads(workflow_provider.icon)
|
icon = json.loads(workflow_provider.icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||||
try:
|
try:
|
||||||
api_provider: ApiToolProvider | None = (
|
api_provider: ApiToolProvider | None = (
|
||||||
db.session.query(ApiToolProvider)
|
db.session.query(ApiToolProvider)
|
||||||
@ -918,13 +913,13 @@ class ToolManager:
|
|||||||
if api_provider is None:
|
if api_provider is None:
|
||||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||||
|
|
||||||
icon: dict = json.loads(api_provider.icon)
|
icon = json.loads(api_provider.icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||||
try:
|
try:
|
||||||
mcp_provider: MCPToolProvider | None = (
|
mcp_provider: MCPToolProvider | None = (
|
||||||
db.session.query(MCPToolProvider)
|
db.session.query(MCPToolProvider)
|
||||||
@ -945,7 +940,7 @@ class ToolManager:
|
|||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
provider_type: ToolProviderType,
|
provider_type: ToolProviderType,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
) -> Union[str, dict[str, Any]]:
|
) -> str | Mapping[str, str]:
|
||||||
"""
|
"""
|
||||||
get the tool icon
|
get the tool icon
|
||||||
|
|
||||||
@ -970,11 +965,10 @@ class ToolManager:
|
|||||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||||
elif provider_type == ToolProviderType.PLUGIN:
|
elif provider_type == ToolProviderType.PLUGIN:
|
||||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||||
if isinstance(provider, PluginToolProviderController):
|
try:
|
||||||
try:
|
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
except Exception:
|
||||||
except Exception:
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
|
||||||
raise ValueError(f"plugin provider {provider_id} not found")
|
raise ValueError(f"plugin provider {provider_id} not found")
|
||||||
elif provider_type == ToolProviderType.MCP:
|
elif provider_type == ToolProviderType.MCP:
|
||||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||||
|
|||||||
132
api/core/workflow/README.md
Normal file
132
api/core/workflow/README.md
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
# Workflow
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
The graph engine follows a layered architecture with strict dependency rules:
|
||||||
|
|
||||||
|
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
|
||||||
|
|
||||||
|
- **Manager** - External control interface for stop/pause/resume commands
|
||||||
|
- **Worker** - Node execution runtime
|
||||||
|
- **Command Processing** - Handles control commands (abort, pause, resume)
|
||||||
|
- **Event Management** - Event propagation and layer notifications
|
||||||
|
- **Graph Traversal** - Edge processing and skip propagation
|
||||||
|
- **Response Coordinator** - Path tracking and session management
|
||||||
|
- **Layers** - Pluggable middleware (debug logging, execution limits)
|
||||||
|
- **Command Channels** - Communication channels (InMemory, Redis)
|
||||||
|
|
||||||
|
1. **Graph** (`graph/`) - Graph structure and runtime state
|
||||||
|
|
||||||
|
- **Graph Template** - Workflow definition
|
||||||
|
- **Edge** - Node connections with conditions
|
||||||
|
- **Runtime State Protocol** - State management interface
|
||||||
|
|
||||||
|
1. **Nodes** (`nodes/`) - Node implementations
|
||||||
|
|
||||||
|
- **Base** - Abstract node classes and variable parsing
|
||||||
|
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
|
||||||
|
|
||||||
|
1. **Events** (`node_events/`) - Event system
|
||||||
|
|
||||||
|
- **Base** - Event protocols
|
||||||
|
- **Node Events** - Node lifecycle events
|
||||||
|
|
||||||
|
1. **Entities** (`entities/`) - Domain models
|
||||||
|
|
||||||
|
- **Variable Pool** - Variable storage
|
||||||
|
- **Graph Init Params** - Initialization configuration
|
||||||
|
|
||||||
|
## Key Design Patterns
|
||||||
|
|
||||||
|
### Command Channel Pattern
|
||||||
|
|
||||||
|
External workflow control via Redis or in-memory channels:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Send stop command to running workflow
|
||||||
|
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
|
||||||
|
channel.send_command(AbortCommand(reason="User requested"))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Layer System
|
||||||
|
|
||||||
|
Extensible middleware for cross-cutting concerns:
|
||||||
|
|
||||||
|
```python
|
||||||
|
engine = GraphEngine(graph)
|
||||||
|
engine.add_layer(DebugLoggingLayer(level="INFO"))
|
||||||
|
engine.add_layer(ExecutionLimitsLayer(max_nodes=100))
|
||||||
|
```
|
||||||
|
|
||||||
|
### Event-Driven Architecture
|
||||||
|
|
||||||
|
All node executions emit events for monitoring and integration:
|
||||||
|
|
||||||
|
- `NodeRunStartedEvent` - Node execution begins
|
||||||
|
- `NodeRunSucceededEvent` - Node completes successfully
|
||||||
|
- `NodeRunFailedEvent` - Node encounters error
|
||||||
|
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
|
||||||
|
|
||||||
|
### Variable Pool
|
||||||
|
|
||||||
|
Centralized variable storage with namespace isolation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Variables scoped by node_id
|
||||||
|
pool.add(["node1", "output"], value)
|
||||||
|
result = pool.get(["node1", "output"])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Import Architecture Rules
|
||||||
|
|
||||||
|
The codebase enforces strict layering via import-linter:
|
||||||
|
|
||||||
|
1. **Workflow Layers** (top to bottom):
|
||||||
|
|
||||||
|
- graph_engine → graph_events → graph → nodes → node_events → entities
|
||||||
|
|
||||||
|
1. **Graph Engine Internal Layers**:
|
||||||
|
|
||||||
|
- orchestration → command_processing → event_management → graph_traversal → domain
|
||||||
|
|
||||||
|
1. **Domain Isolation**:
|
||||||
|
|
||||||
|
- Domain models cannot import from infrastructure layers
|
||||||
|
|
||||||
|
1. **Command Channel Independence**:
|
||||||
|
|
||||||
|
- InMemory and Redis channels must remain independent
|
||||||
|
|
||||||
|
## Common Tasks
|
||||||
|
|
||||||
|
### Adding a New Node Type
|
||||||
|
|
||||||
|
1. Create node class in `nodes/<node_type>/`
|
||||||
|
1. Inherit from `BaseNode` or appropriate base class
|
||||||
|
1. Implement `_run()` method
|
||||||
|
1. Register in `nodes/node_mapping.py`
|
||||||
|
1. Add tests in `tests/unit_tests/core/workflow/nodes/`
|
||||||
|
|
||||||
|
### Implementing a Custom Layer
|
||||||
|
|
||||||
|
1. Create class inheriting from `Layer` base
|
||||||
|
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
|
||||||
|
1. Add to engine via `engine.add_layer()`
|
||||||
|
|
||||||
|
### Debugging Workflow Execution
|
||||||
|
|
||||||
|
Enable debug logging layer:
|
||||||
|
|
||||||
|
```python
|
||||||
|
debug_layer = DebugLoggingLayer(
|
||||||
|
level="DEBUG",
|
||||||
|
include_inputs=True,
|
||||||
|
include_outputs=True
|
||||||
|
)
|
||||||
|
```
|
||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
@ -7,4 +5,4 @@ class AgentNodeStrategyInit(BaseModel):
|
|||||||
"""Agent node strategy initialization data."""
|
"""Agent node strategy initialization data."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
icon: Optional[str] = None
|
icon: str | None = None
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
@ -14,17 +13,24 @@ class GraphRuntimeState(BaseModel):
|
|||||||
_start_at: float = PrivateAttr()
|
_start_at: float = PrivateAttr()
|
||||||
_total_tokens: int = PrivateAttr(default=0)
|
_total_tokens: int = PrivateAttr(default=0)
|
||||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||||
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
||||||
_node_run_steps: int = PrivateAttr(default=0)
|
_node_run_steps: int = PrivateAttr(default=0)
|
||||||
|
_ready_queue_json: str = PrivateAttr()
|
||||||
|
_graph_execution_json: str = PrivateAttr()
|
||||||
|
_response_coordinator_json: str = PrivateAttr()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
start_at: float,
|
start_at: float,
|
||||||
total_tokens: int = 0,
|
total_tokens: int = 0,
|
||||||
llm_usage: LLMUsage | None = None,
|
llm_usage: LLMUsage | None = None,
|
||||||
outputs: dict[str, Any] | None = None,
|
outputs: dict[str, object] | None = None,
|
||||||
node_run_steps: int = 0,
|
node_run_steps: int = 0,
|
||||||
|
ready_queue_json: str = "",
|
||||||
|
graph_execution_json: str = "",
|
||||||
|
response_coordinator_json: str = "",
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
):
|
):
|
||||||
"""Initialize the GraphRuntimeState with validation."""
|
"""Initialize the GraphRuntimeState with validation."""
|
||||||
@ -51,6 +57,10 @@ class GraphRuntimeState(BaseModel):
|
|||||||
raise ValueError("node_run_steps must be non-negative")
|
raise ValueError("node_run_steps must be non-negative")
|
||||||
self._node_run_steps = node_run_steps
|
self._node_run_steps = node_run_steps
|
||||||
|
|
||||||
|
self._ready_queue_json = ready_queue_json
|
||||||
|
self._graph_execution_json = graph_execution_json
|
||||||
|
self._response_coordinator_json = response_coordinator_json
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variable_pool(self) -> VariablePool:
|
def variable_pool(self) -> VariablePool:
|
||||||
"""Get the variable pool."""
|
"""Get the variable pool."""
|
||||||
@ -90,24 +100,24 @@ class GraphRuntimeState(BaseModel):
|
|||||||
self._llm_usage = value.model_copy()
|
self._llm_usage = value.model_copy()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def outputs(self) -> dict[str, Any]:
|
def outputs(self) -> dict[str, object]:
|
||||||
"""Get a copy of the outputs dictionary."""
|
"""Get a copy of the outputs dictionary."""
|
||||||
return deepcopy(self._outputs)
|
return deepcopy(self._outputs)
|
||||||
|
|
||||||
@outputs.setter
|
@outputs.setter
|
||||||
def outputs(self, value: dict[str, Any]) -> None:
|
def outputs(self, value: dict[str, object]) -> None:
|
||||||
"""Set the outputs dictionary."""
|
"""Set the outputs dictionary."""
|
||||||
self._outputs = deepcopy(value)
|
self._outputs = deepcopy(value)
|
||||||
|
|
||||||
def set_output(self, key: str, value: Any) -> None:
|
def set_output(self, key: str, value: object) -> None:
|
||||||
"""Set a single output value."""
|
"""Set a single output value."""
|
||||||
self._outputs[key] = deepcopy(value)
|
self._outputs[key] = deepcopy(value)
|
||||||
|
|
||||||
def get_output(self, key: str, default: Any = None) -> Any:
|
def get_output(self, key: str, default: object = None) -> object:
|
||||||
"""Get a single output value."""
|
"""Get a single output value."""
|
||||||
return deepcopy(self._outputs.get(key, default))
|
return deepcopy(self._outputs.get(key, default))
|
||||||
|
|
||||||
def update_outputs(self, updates: dict[str, Any]) -> None:
|
def update_outputs(self, updates: dict[str, object]) -> None:
|
||||||
"""Update multiple output values."""
|
"""Update multiple output values."""
|
||||||
for key, value in updates.items():
|
for key, value in updates.items():
|
||||||
self._outputs[key] = deepcopy(value)
|
self._outputs[key] = deepcopy(value)
|
||||||
@ -133,3 +143,18 @@ class GraphRuntimeState(BaseModel):
|
|||||||
if tokens < 0:
|
if tokens < 0:
|
||||||
raise ValueError("tokens must be non-negative")
|
raise ValueError("tokens must be non-negative")
|
||||||
self._total_tokens += tokens
|
self._total_tokens += tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready_queue_json(self) -> str:
|
||||||
|
"""Get a copy of the ready queue state."""
|
||||||
|
return self._ready_queue_json
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph_execution_json(self) -> str:
|
||||||
|
"""Get a copy of the serialized graph execution state."""
|
||||||
|
return self._graph_execution_json
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response_coordinator_json(self) -> str:
|
||||||
|
"""Get a copy of the serialized response coordinator state."""
|
||||||
|
return self._response_coordinator_json
|
||||||
|
|||||||
@ -188,9 +188,9 @@ class Graph:
|
|||||||
for node_id, node_config in node_configs_map.items():
|
for node_id, node_config in node_configs_map.items():
|
||||||
try:
|
try:
|
||||||
node_instance = node_factory.create_node(node_config)
|
node_instance = node_factory.create_node(node_config)
|
||||||
except ValueError as e:
|
except Exception:
|
||||||
logger.warning("Failed to create node instance: %s", str(e))
|
logger.exception("Failed to create node instance for node_id %s", node_id)
|
||||||
continue
|
raise
|
||||||
nodes[node_id] = node_instance
|
nodes[node_id] = node_instance
|
||||||
|
|
||||||
return nodes
|
return nodes
|
||||||
|
|||||||
@ -97,8 +97,12 @@ class RedisChannel:
|
|||||||
Returns:
|
Returns:
|
||||||
Deserialized command or None if invalid
|
Deserialized command or None if invalid
|
||||||
"""
|
"""
|
||||||
|
command_type_value = data.get("command_type")
|
||||||
|
if not isinstance(command_type_value, str):
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
command_type = CommandType(data.get("command_type"))
|
command_type = CommandType(command_type_value)
|
||||||
|
|
||||||
if command_type == CommandType.ABORT:
|
if command_type == CommandType.ABORT:
|
||||||
return AbortCommand(**data)
|
return AbortCommand(**data)
|
||||||
|
|||||||
@ -5,12 +5,10 @@ This package contains the core domain entities, value objects, and aggregates
|
|||||||
that represent the business concepts of workflow graph execution.
|
that represent the business concepts of workflow graph execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .execution_context import ExecutionContext
|
|
||||||
from .graph_execution import GraphExecution
|
from .graph_execution import GraphExecution
|
||||||
from .node_execution import NodeExecution
|
from .node_execution import NodeExecution
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ExecutionContext",
|
|
||||||
"GraphExecution",
|
"GraphExecution",
|
||||||
"NodeExecution",
|
"NodeExecution",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,37 +0,0 @@
|
|||||||
"""
|
|
||||||
ExecutionContext value object containing immutable execution parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ExecutionContext:
|
|
||||||
"""
|
|
||||||
Immutable value object containing the context for a graph execution.
|
|
||||||
|
|
||||||
This encapsulates all the contextual information needed to execute a workflow,
|
|
||||||
keeping it separate from the mutable execution state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tenant_id: str
|
|
||||||
app_id: str
|
|
||||||
workflow_id: str
|
|
||||||
user_id: str
|
|
||||||
user_from: UserFrom
|
|
||||||
invoke_from: InvokeFrom
|
|
||||||
call_depth: int
|
|
||||||
max_execution_steps: int
|
|
||||||
max_execution_time: int
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
"""Validate execution context parameters."""
|
|
||||||
if self.call_depth < 0:
|
|
||||||
raise ValueError("Call depth must be non-negative")
|
|
||||||
if self.max_execution_steps <= 0:
|
|
||||||
raise ValueError("Max execution steps must be positive")
|
|
||||||
if self.max_execution_time <= 0:
|
|
||||||
raise ValueError("Max execution time must be positive")
|
|
||||||
@ -1,12 +1,94 @@
|
|||||||
"""
|
"""GraphExecution aggregate root managing the overall graph execution state."""
|
||||||
GraphExecution aggregate root managing the overall graph execution state.
|
|
||||||
"""
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from importlib import import_module
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeState
|
||||||
|
|
||||||
from .node_execution import NodeExecution
|
from .node_execution import NodeExecution
|
||||||
|
|
||||||
|
|
||||||
|
class GraphExecutionErrorState(BaseModel):
|
||||||
|
"""Serializable representation of an execution error."""
|
||||||
|
|
||||||
|
module: str = Field(description="Module containing the exception class")
|
||||||
|
qualname: str = Field(description="Qualified name of the exception class")
|
||||||
|
message: str | None = Field(default=None, description="Exception message string")
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecutionState(BaseModel):
|
||||||
|
"""Serializable representation of a node execution entity."""
|
||||||
|
|
||||||
|
node_id: str
|
||||||
|
state: NodeState = Field(default=NodeState.UNKNOWN)
|
||||||
|
retry_count: int = Field(default=0)
|
||||||
|
execution_id: str | None = Field(default=None)
|
||||||
|
error: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphExecutionState(BaseModel):
|
||||||
|
"""Pydantic model describing serialized GraphExecution state."""
|
||||||
|
|
||||||
|
type: Literal["GraphExecution"] = Field(default="GraphExecution")
|
||||||
|
version: str = Field(default="1.0")
|
||||||
|
workflow_id: str
|
||||||
|
started: bool = Field(default=False)
|
||||||
|
completed: bool = Field(default=False)
|
||||||
|
aborted: bool = Field(default=False)
|
||||||
|
error: GraphExecutionErrorState | None = Field(default=None)
|
||||||
|
node_executions: list[NodeExecutionState] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
|
||||||
|
"""Convert an exception into its serializable representation."""
|
||||||
|
|
||||||
|
if error is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return GraphExecutionErrorState(
|
||||||
|
module=error.__class__.__module__,
|
||||||
|
qualname=error.__class__.__qualname__,
|
||||||
|
message=str(error),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
|
||||||
|
"""Locate an exception class from its module and qualified name."""
|
||||||
|
|
||||||
|
module = import_module(module_name)
|
||||||
|
attr: object = module
|
||||||
|
for part in qualname.split("."):
|
||||||
|
attr = getattr(attr, part)
|
||||||
|
|
||||||
|
if isinstance(attr, type) and issubclass(attr, Exception):
|
||||||
|
return attr
|
||||||
|
|
||||||
|
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
|
||||||
|
"""Reconstruct an exception instance from serialized data."""
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
exception_class = _resolve_exception_class(state.module, state.qualname)
|
||||||
|
if state.message is None:
|
||||||
|
return exception_class()
|
||||||
|
return exception_class(state.message)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to RuntimeError when reconstruction fails
|
||||||
|
if state.message is None:
|
||||||
|
return RuntimeError(state.qualname)
|
||||||
|
return RuntimeError(state.message)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphExecution:
|
class GraphExecution:
|
||||||
"""
|
"""
|
||||||
@ -69,3 +151,57 @@ class GraphExecution:
|
|||||||
if not self.error:
|
if not self.error:
|
||||||
return None
|
return None
|
||||||
return str(self.error)
|
return str(self.error)
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
"""Serialize the aggregate state into a JSON string."""
|
||||||
|
|
||||||
|
node_states = [
|
||||||
|
NodeExecutionState(
|
||||||
|
node_id=node_id,
|
||||||
|
state=node_execution.state,
|
||||||
|
retry_count=node_execution.retry_count,
|
||||||
|
execution_id=node_execution.execution_id,
|
||||||
|
error=node_execution.error,
|
||||||
|
)
|
||||||
|
for node_id, node_execution in sorted(self.node_executions.items())
|
||||||
|
]
|
||||||
|
|
||||||
|
state = GraphExecutionState(
|
||||||
|
workflow_id=self.workflow_id,
|
||||||
|
started=self.started,
|
||||||
|
completed=self.completed,
|
||||||
|
aborted=self.aborted,
|
||||||
|
error=_serialize_error(self.error),
|
||||||
|
node_executions=node_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
return state.model_dump_json()
|
||||||
|
|
||||||
|
def loads(self, data: str) -> None:
|
||||||
|
"""Restore aggregate state from a serialized JSON string."""
|
||||||
|
|
||||||
|
state = GraphExecutionState.model_validate_json(data)
|
||||||
|
|
||||||
|
if state.type != "GraphExecution":
|
||||||
|
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||||
|
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||||
|
|
||||||
|
if self.workflow_id != state.workflow_id:
|
||||||
|
raise ValueError("Serialized workflow_id does not match aggregate identity")
|
||||||
|
|
||||||
|
self.started = state.started
|
||||||
|
self.completed = state.completed
|
||||||
|
self.aborted = state.aborted
|
||||||
|
self.error = _deserialize_error(state.error)
|
||||||
|
self.node_executions = {
|
||||||
|
item.node_id: NodeExecution(
|
||||||
|
node_id=item.node_id,
|
||||||
|
state=item.state,
|
||||||
|
retry_count=item.retry_count,
|
||||||
|
execution_id=item.execution_id,
|
||||||
|
error=item.error,
|
||||||
|
)
|
||||||
|
for item in state.node_executions
|
||||||
|
}
|
||||||
|
|||||||
211
api/core/workflow/graph_engine/error_handler.py
Normal file
211
api/core/workflow/graph_engine/error_handler.py
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Main error handler that coordinates error strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, final
|
||||||
|
|
||||||
|
from core.workflow.enums import (
|
||||||
|
ErrorStrategy as ErrorStrategyEnum,
|
||||||
|
)
|
||||||
|
from core.workflow.enums import (
|
||||||
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
WorkflowNodeExecutionStatus,
|
||||||
|
)
|
||||||
|
from core.workflow.graph import Graph
|
||||||
|
from core.workflow.graph_events import (
|
||||||
|
GraphNodeEventBase,
|
||||||
|
NodeRunExceptionEvent,
|
||||||
|
NodeRunFailedEvent,
|
||||||
|
NodeRunRetryEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.node_events import NodeRunResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .domain import GraphExecution
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class ErrorHandler:
|
||||||
|
"""
|
||||||
|
Coordinates error handling strategies for node failures.
|
||||||
|
|
||||||
|
This acts as a facade for the various error strategies,
|
||||||
|
selecting and applying the appropriate strategy based on
|
||||||
|
node configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
||||||
|
"""
|
||||||
|
Initialize the error handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph: The workflow graph
|
||||||
|
graph_execution: The graph execution state
|
||||||
|
"""
|
||||||
|
self._graph = graph
|
||||||
|
self._graph_execution = graph_execution
|
||||||
|
|
||||||
|
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
||||||
|
"""
|
||||||
|
Handle a node failure event.
|
||||||
|
|
||||||
|
Selects and applies the appropriate error strategy based on
|
||||||
|
the node's configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The node failure event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional new event to process, or None to abort
|
||||||
|
"""
|
||||||
|
node = self._graph.nodes[event.node_id]
|
||||||
|
# Get retry count from NodeExecution
|
||||||
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
|
retry_count = node_execution.retry_count
|
||||||
|
|
||||||
|
# First check if retry is configured and not exhausted
|
||||||
|
if node.retry and retry_count < node.retry_config.max_retries:
|
||||||
|
result = self._handle_retry(event, retry_count)
|
||||||
|
if result:
|
||||||
|
# Retry count will be incremented when NodeRunRetryEvent is handled
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Apply configured error strategy
|
||||||
|
strategy = node.error_strategy
|
||||||
|
|
||||||
|
match strategy:
|
||||||
|
case None:
|
||||||
|
return self._handle_abort(event)
|
||||||
|
case ErrorStrategyEnum.FAIL_BRANCH:
|
||||||
|
return self._handle_fail_branch(event)
|
||||||
|
case ErrorStrategyEnum.DEFAULT_VALUE:
|
||||||
|
return self._handle_default_value(event)
|
||||||
|
|
||||||
|
def _handle_abort(self, event: NodeRunFailedEvent):
|
||||||
|
"""
|
||||||
|
Handle error by aborting execution.
|
||||||
|
|
||||||
|
This is the default strategy when no other strategy is specified.
|
||||||
|
It stops the entire graph execution when a node fails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The failure event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None - signals abortion
|
||||||
|
"""
|
||||||
|
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
||||||
|
# Return None to signal that execution should stop
|
||||||
|
|
||||||
|
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
|
||||||
|
"""
|
||||||
|
Handle error by retrying the node.
|
||||||
|
|
||||||
|
This strategy re-attempts node execution up to a configured
|
||||||
|
maximum number of retries with configurable intervals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The failure event
|
||||||
|
retry_count: Current retry attempt count
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NodeRunRetryEvent if retry should occur, None otherwise
|
||||||
|
"""
|
||||||
|
node = self._graph.nodes[event.node_id]
|
||||||
|
|
||||||
|
# Check if we've exceeded max retries
|
||||||
|
if not node.retry or retry_count >= node.retry_config.max_retries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Wait for retry interval
|
||||||
|
time.sleep(node.retry_config.retry_interval_seconds)
|
||||||
|
|
||||||
|
# Create retry event
|
||||||
|
return NodeRunRetryEvent(
|
||||||
|
id=event.id,
|
||||||
|
node_title=node.title,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_run_result=event.node_run_result,
|
||||||
|
start_at=event.start_at,
|
||||||
|
error=event.error,
|
||||||
|
retry_index=retry_count + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_fail_branch(self, event: NodeRunFailedEvent):
|
||||||
|
"""
|
||||||
|
Handle error by taking the fail branch.
|
||||||
|
|
||||||
|
This strategy converts failures to exceptions and routes execution
|
||||||
|
through a designated fail-branch edge.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The failure event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NodeRunExceptionEvent to continue via fail branch
|
||||||
|
"""
|
||||||
|
outputs = {
|
||||||
|
"error_message": event.node_run_result.error,
|
||||||
|
"error_type": event.node_run_result.error_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
return NodeRunExceptionEvent(
|
||||||
|
id=event.id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
start_at=event.start_at,
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||||
|
inputs=event.node_run_result.inputs,
|
||||||
|
process_data=event.node_run_result.process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
edge_source_handle="fail-branch",
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
error=event.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_default_value(self, event: NodeRunFailedEvent):
|
||||||
|
"""
|
||||||
|
Handle error by using default values.
|
||||||
|
|
||||||
|
This strategy allows nodes to fail gracefully by providing
|
||||||
|
predefined default output values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The failure event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NodeRunExceptionEvent with default values
|
||||||
|
"""
|
||||||
|
node = self._graph.nodes[event.node_id]
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
**node.default_value_dict,
|
||||||
|
"error_message": event.node_run_result.error,
|
||||||
|
"error_type": event.node_run_result.error_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
return NodeRunExceptionEvent(
|
||||||
|
id=event.id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
start_at=event.start_at,
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||||
|
inputs=event.node_run_result.inputs,
|
||||||
|
process_data=event.node_run_result.process_data,
|
||||||
|
outputs=outputs,
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
error=event.error,
|
||||||
|
)
|
||||||
@ -1,20 +0,0 @@
|
|||||||
"""
|
|
||||||
Error handling strategies for graph engine.
|
|
||||||
|
|
||||||
This package implements different error recovery strategies using
|
|
||||||
the Strategy pattern for clean separation of concerns.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .abort_strategy import AbortStrategy
|
|
||||||
from .default_value_strategy import DefaultValueStrategy
|
|
||||||
from .error_handler import ErrorHandler
|
|
||||||
from .fail_branch_strategy import FailBranchStrategy
|
|
||||||
from .retry_strategy import RetryStrategy
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AbortStrategy",
|
|
||||||
"DefaultValueStrategy",
|
|
||||||
"ErrorHandler",
|
|
||||||
"FailBranchStrategy",
|
|
||||||
"RetryStrategy",
|
|
||||||
]
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
Abort error strategy implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class AbortStrategy:
|
|
||||||
"""
|
|
||||||
Error strategy that aborts execution on failure.
|
|
||||||
|
|
||||||
This is the default strategy when no other strategy is specified.
|
|
||||||
It stops the entire graph execution when a node fails.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle error by aborting execution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The failure event
|
|
||||||
graph: The workflow graph
|
|
||||||
retry_count: Current retry attempt count (unused)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None - signals abortion
|
|
||||||
"""
|
|
||||||
_ = graph
|
|
||||||
_ = retry_count
|
|
||||||
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
|
|
||||||
|
|
||||||
# Return None to signal that execution should stop
|
|
||||||
return None
|
|
||||||
@ -1,58 +0,0 @@
|
|||||||
"""
|
|
||||||
Default value error strategy implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class DefaultValueStrategy:
|
|
||||||
"""
|
|
||||||
Error strategy that uses default values on failure.
|
|
||||||
|
|
||||||
This strategy allows nodes to fail gracefully by providing
|
|
||||||
predefined default output values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle error by using default values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The failure event
|
|
||||||
graph: The workflow graph
|
|
||||||
retry_count: Current retry attempt count (unused)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
NodeRunExceptionEvent with default values
|
|
||||||
"""
|
|
||||||
_ = retry_count
|
|
||||||
node = graph.nodes[event.node_id]
|
|
||||||
|
|
||||||
outputs = {
|
|
||||||
**node.default_value_dict,
|
|
||||||
"error_message": event.node_run_result.error,
|
|
||||||
"error_type": event.node_run_result.error_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
return NodeRunExceptionEvent(
|
|
||||||
id=event.id,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
start_at=event.start_at,
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
|
||||||
inputs=event.node_run_result.inputs,
|
|
||||||
process_data=event.node_run_result.process_data,
|
|
||||||
outputs=outputs,
|
|
||||||
metadata={
|
|
||||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.DEFAULT_VALUE,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
error=event.error,
|
|
||||||
)
|
|
||||||
@ -1,81 +0,0 @@
|
|||||||
"""
|
|
||||||
Main error handler that coordinates error strategies.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, final
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy as ErrorStrategyEnum
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
|
||||||
|
|
||||||
from .abort_strategy import AbortStrategy
|
|
||||||
from .default_value_strategy import DefaultValueStrategy
|
|
||||||
from .fail_branch_strategy import FailBranchStrategy
|
|
||||||
from .retry_strategy import RetryStrategy
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..domain import GraphExecution
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class ErrorHandler:
|
|
||||||
"""
|
|
||||||
Coordinates error handling strategies for node failures.
|
|
||||||
|
|
||||||
This acts as a facade for the various error strategies,
|
|
||||||
selecting and applying the appropriate strategy based on
|
|
||||||
node configuration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
|
|
||||||
"""
|
|
||||||
Initialize the error handler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph: The workflow graph
|
|
||||||
graph_execution: The graph execution state
|
|
||||||
"""
|
|
||||||
self._graph = graph
|
|
||||||
self._graph_execution = graph_execution
|
|
||||||
|
|
||||||
# Initialize strategies
|
|
||||||
self._abort_strategy = AbortStrategy()
|
|
||||||
self._retry_strategy = RetryStrategy()
|
|
||||||
self._fail_branch_strategy = FailBranchStrategy()
|
|
||||||
self._default_value_strategy = DefaultValueStrategy()
|
|
||||||
|
|
||||||
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle a node failure event.
|
|
||||||
|
|
||||||
Selects and applies the appropriate error strategy based on
|
|
||||||
the node's configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The node failure event
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional new event to process, or None to abort
|
|
||||||
"""
|
|
||||||
node = self._graph.nodes[event.node_id]
|
|
||||||
# Get retry count from NodeExecution
|
|
||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
|
||||||
retry_count = node_execution.retry_count
|
|
||||||
|
|
||||||
# First check if retry is configured and not exhausted
|
|
||||||
if node.retry and retry_count < node.retry_config.max_retries:
|
|
||||||
result = self._retry_strategy.handle_error(event, self._graph, retry_count)
|
|
||||||
if result:
|
|
||||||
# Retry count will be incremented when NodeRunRetryEvent is handled
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Apply configured error strategy
|
|
||||||
strategy = node.error_strategy
|
|
||||||
|
|
||||||
match strategy:
|
|
||||||
case None:
|
|
||||||
return self._abort_strategy.handle_error(event, self._graph, retry_count)
|
|
||||||
case ErrorStrategyEnum.FAIL_BRANCH:
|
|
||||||
return self._fail_branch_strategy.handle_error(event, self._graph, retry_count)
|
|
||||||
case ErrorStrategyEnum.DEFAULT_VALUE:
|
|
||||||
return self._default_value_strategy.handle_error(event, self._graph, retry_count)
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
"""
|
|
||||||
Fail branch error strategy implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class FailBranchStrategy:
|
|
||||||
"""
|
|
||||||
Error strategy that continues execution via a fail branch.
|
|
||||||
|
|
||||||
This strategy converts failures to exceptions and routes execution
|
|
||||||
through a designated fail-branch edge.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle error by taking the fail branch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The failure event
|
|
||||||
graph: The workflow graph
|
|
||||||
retry_count: Current retry attempt count (unused)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
NodeRunExceptionEvent to continue via fail branch
|
|
||||||
"""
|
|
||||||
_ = graph
|
|
||||||
_ = retry_count
|
|
||||||
outputs = {
|
|
||||||
"error_message": event.node_run_result.error,
|
|
||||||
"error_type": event.node_run_result.error_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
return NodeRunExceptionEvent(
|
|
||||||
id=event.id,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
start_at=event.start_at,
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
|
||||||
inputs=event.node_run_result.inputs,
|
|
||||||
process_data=event.node_run_result.process_data,
|
|
||||||
outputs=outputs,
|
|
||||||
edge_source_handle="fail-branch",
|
|
||||||
metadata={
|
|
||||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategy.FAIL_BRANCH,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
error=event.error,
|
|
||||||
)
|
|
||||||
@ -1,52 +0,0 @@
|
|||||||
"""
|
|
||||||
Retry error strategy implementation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunRetryEvent
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class RetryStrategy:
|
|
||||||
"""
|
|
||||||
Error strategy that retries failed nodes.
|
|
||||||
|
|
||||||
This strategy re-attempts node execution up to a configured
|
|
||||||
maximum number of retries with configurable intervals.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle error by retrying the node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The failure event
|
|
||||||
graph: The workflow graph
|
|
||||||
retry_count: Current retry attempt count
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
NodeRunRetryEvent if retry should occur, None otherwise
|
|
||||||
"""
|
|
||||||
node = graph.nodes[event.node_id]
|
|
||||||
|
|
||||||
# Check if we've exceeded max retries
|
|
||||||
if not node.retry or retry_count >= node.retry_config.max_retries:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Wait for retry interval
|
|
||||||
time.sleep(node.retry_config.retry_interval_seconds)
|
|
||||||
|
|
||||||
# Create retry event
|
|
||||||
return NodeRunRetryEvent(
|
|
||||||
id=event.id,
|
|
||||||
node_title=node.title,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
node_run_result=event.node_run_result,
|
|
||||||
start_at=event.start_at,
|
|
||||||
error=event.error,
|
|
||||||
retry_index=retry_count + 1,
|
|
||||||
)
|
|
||||||
@ -3,6 +3,7 @@ Event handler implementations for different event types.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from functools import singledispatchmethod
|
||||||
from typing import TYPE_CHECKING, final
|
from typing import TYPE_CHECKING, final
|
||||||
|
|
||||||
from core.workflow.entities import GraphRuntimeState
|
from core.workflow.entities import GraphRuntimeState
|
||||||
@ -31,9 +32,9 @@ from ..domain.graph_execution import GraphExecution
|
|||||||
from ..response_coordinator import ResponseStreamCoordinator
|
from ..response_coordinator import ResponseStreamCoordinator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..error_handling import ErrorHandler
|
from ..error_handler import ErrorHandler
|
||||||
|
from ..graph_state_manager import GraphStateManager
|
||||||
from ..graph_traversal import EdgeProcessor
|
from ..graph_traversal import EdgeProcessor
|
||||||
from ..state_management import UnifiedStateManager
|
|
||||||
from .event_manager import EventManager
|
from .event_manager import EventManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -56,7 +57,7 @@ class EventHandler:
|
|||||||
response_coordinator: ResponseStreamCoordinator,
|
response_coordinator: ResponseStreamCoordinator,
|
||||||
event_collector: "EventManager",
|
event_collector: "EventManager",
|
||||||
edge_processor: "EdgeProcessor",
|
edge_processor: "EdgeProcessor",
|
||||||
state_manager: "UnifiedStateManager",
|
state_manager: "GraphStateManager",
|
||||||
error_handler: "ErrorHandler",
|
error_handler: "ErrorHandler",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -81,7 +82,7 @@ class EventHandler:
|
|||||||
self._state_manager = state_manager
|
self._state_manager = state_manager
|
||||||
self._error_handler = error_handler
|
self._error_handler = error_handler
|
||||||
|
|
||||||
def handle_event(self, event: GraphNodeEventBase) -> None:
|
def dispatch(self, event: GraphNodeEventBase) -> None:
|
||||||
"""
|
"""
|
||||||
Handle any node event by dispatching to the appropriate handler.
|
Handle any node event by dispatching to the appropriate handler.
|
||||||
|
|
||||||
@ -92,42 +93,27 @@ class EventHandler:
|
|||||||
if event.in_loop_id or event.in_iteration_id:
|
if event.in_loop_id or event.in_iteration_id:
|
||||||
self._event_collector.collect(event)
|
self._event_collector.collect(event)
|
||||||
return
|
return
|
||||||
|
return self._dispatch(event)
|
||||||
|
|
||||||
# Handle specific event types
|
@singledispatchmethod
|
||||||
if isinstance(event, NodeRunStartedEvent):
|
def _dispatch(self, event: GraphNodeEventBase) -> None:
|
||||||
self._handle_node_started(event)
|
self._event_collector.collect(event)
|
||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
logger.warning("Unhandled event type: %s", type(event).__name__)
|
||||||
self._handle_stream_chunk(event)
|
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
|
||||||
self._handle_node_succeeded(event)
|
|
||||||
elif isinstance(event, NodeRunFailedEvent):
|
|
||||||
self._handle_node_failed(event)
|
|
||||||
elif isinstance(event, NodeRunExceptionEvent):
|
|
||||||
self._handle_node_exception(event)
|
|
||||||
elif isinstance(event, NodeRunRetryEvent):
|
|
||||||
self._handle_node_retry(event)
|
|
||||||
elif isinstance(
|
|
||||||
event,
|
|
||||||
(
|
|
||||||
NodeRunIterationStartedEvent,
|
|
||||||
NodeRunIterationNextEvent,
|
|
||||||
NodeRunIterationSucceededEvent,
|
|
||||||
NodeRunIterationFailedEvent,
|
|
||||||
NodeRunLoopStartedEvent,
|
|
||||||
NodeRunLoopNextEvent,
|
|
||||||
NodeRunLoopSucceededEvent,
|
|
||||||
NodeRunLoopFailedEvent,
|
|
||||||
NodeRunAgentLogEvent,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
# Iteration and loop events are collected directly
|
|
||||||
self._event_collector.collect(event)
|
|
||||||
else:
|
|
||||||
# Collect unhandled events
|
|
||||||
self._event_collector.collect(event)
|
|
||||||
logger.warning("Unhandled event type: %s", type(event).__name__)
|
|
||||||
|
|
||||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
@_dispatch.register(NodeRunIterationStartedEvent)
|
||||||
|
@_dispatch.register(NodeRunIterationNextEvent)
|
||||||
|
@_dispatch.register(NodeRunIterationSucceededEvent)
|
||||||
|
@_dispatch.register(NodeRunIterationFailedEvent)
|
||||||
|
@_dispatch.register(NodeRunLoopStartedEvent)
|
||||||
|
@_dispatch.register(NodeRunLoopNextEvent)
|
||||||
|
@_dispatch.register(NodeRunLoopSucceededEvent)
|
||||||
|
@_dispatch.register(NodeRunLoopFailedEvent)
|
||||||
|
@_dispatch.register(NodeRunAgentLogEvent)
|
||||||
|
def _(self, event: GraphNodeEventBase) -> None:
|
||||||
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunStartedEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle node started event.
|
Handle node started event.
|
||||||
|
|
||||||
@ -144,7 +130,8 @@ class EventHandler:
|
|||||||
# Collect the event
|
# Collect the event
|
||||||
self._event_collector.collect(event)
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunStreamChunkEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle stream chunk event with full processing.
|
Handle stream chunk event with full processing.
|
||||||
|
|
||||||
@ -158,7 +145,8 @@ class EventHandler:
|
|||||||
for stream_event in streaming_events:
|
for stream_event in streaming_events:
|
||||||
self._event_collector.collect(stream_event)
|
self._event_collector.collect(stream_event)
|
||||||
|
|
||||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunSucceededEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle node success by coordinating subsystems.
|
Handle node success by coordinating subsystems.
|
||||||
|
|
||||||
@ -208,7 +196,8 @@ class EventHandler:
|
|||||||
# Collect the event
|
# Collect the event
|
||||||
self._event_collector.collect(event)
|
self._event_collector.collect(event)
|
||||||
|
|
||||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunFailedEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle node failure using error handler.
|
Handle node failure using error handler.
|
||||||
|
|
||||||
@ -223,14 +212,15 @@ class EventHandler:
|
|||||||
|
|
||||||
if result:
|
if result:
|
||||||
# Process the resulting event (retry, exception, etc.)
|
# Process the resulting event (retry, exception, etc.)
|
||||||
self.handle_event(result)
|
self.dispatch(result)
|
||||||
else:
|
else:
|
||||||
# Abort execution
|
# Abort execution
|
||||||
self._graph_execution.fail(RuntimeError(event.error))
|
self._graph_execution.fail(RuntimeError(event.error))
|
||||||
self._event_collector.collect(event)
|
self._event_collector.collect(event)
|
||||||
self._state_manager.finish_execution(event.node_id)
|
self._state_manager.finish_execution(event.node_id)
|
||||||
|
|
||||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunExceptionEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle node exception event (fail-branch strategy).
|
Handle node exception event (fail-branch strategy).
|
||||||
|
|
||||||
@ -241,7 +231,8 @@ class EventHandler:
|
|||||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||||
node_execution.mark_taken()
|
node_execution.mark_taken()
|
||||||
|
|
||||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
@_dispatch.register
|
||||||
|
def _(self, event: NodeRunRetryEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle node retry event.
|
Handle node retry event.
|
||||||
|
|
||||||
|
|||||||
@ -8,16 +8,16 @@ Domain-Driven Design principles for improved maintainability and testability.
|
|||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
import queue
|
import queue
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.entities import GraphRuntimeState
|
from core.workflow.entities import GraphRuntimeState
|
||||||
from core.workflow.enums import NodeExecutionType
|
from core.workflow.enums import NodeExecutionType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
||||||
|
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphNodeEventBase,
|
GraphNodeEventBase,
|
||||||
@ -26,19 +26,19 @@ from core.workflow.graph_events import (
|
|||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
from .command_processing import AbortCommandHandler, CommandProcessor
|
from .command_processing import AbortCommandHandler, CommandProcessor
|
||||||
from .domain import ExecutionContext, GraphExecution
|
from .domain import GraphExecution
|
||||||
from .entities.commands import AbortCommand
|
from .entities.commands import AbortCommand
|
||||||
from .error_handling import ErrorHandler
|
from .error_handler import ErrorHandler
|
||||||
from .event_management import EventHandler, EventManager
|
from .event_management import EventHandler, EventManager
|
||||||
|
from .graph_state_manager import GraphStateManager
|
||||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||||
from .layers.base import GraphEngineLayer
|
from .layers.base import GraphEngineLayer
|
||||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||||
from .protocols.command_channel import CommandChannel
|
from .protocols.command_channel import CommandChannel
|
||||||
|
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
|
||||||
from .response_coordinator import ResponseStreamCoordinator
|
from .response_coordinator import ResponseStreamCoordinator
|
||||||
from .state_management import UnifiedStateManager
|
|
||||||
from .worker_management import WorkerPool
|
from .worker_management import WorkerPool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -55,18 +55,9 @@ class GraphEngine:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tenant_id: str,
|
|
||||||
app_id: str,
|
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
user_id: str,
|
|
||||||
user_from: UserFrom,
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
call_depth: int,
|
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
graph_config: Mapping[str, object],
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
max_execution_steps: int,
|
|
||||||
max_execution_time: int,
|
|
||||||
command_channel: CommandChannel,
|
command_channel: CommandChannel,
|
||||||
min_workers: int | None = None,
|
min_workers: int | None = None,
|
||||||
max_workers: int | None = None,
|
max_workers: int | None = None,
|
||||||
@ -75,27 +66,14 @@ class GraphEngine:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||||
|
|
||||||
# === Domain Models ===
|
|
||||||
# Execution context encapsulates workflow execution metadata
|
|
||||||
self._execution_context = ExecutionContext(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
app_id=app_id,
|
|
||||||
workflow_id=workflow_id,
|
|
||||||
user_id=user_id,
|
|
||||||
user_from=user_from,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
call_depth=call_depth,
|
|
||||||
max_execution_steps=max_execution_steps,
|
|
||||||
max_execution_time=max_execution_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Graph execution tracks the overall execution state
|
# Graph execution tracks the overall execution state
|
||||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||||
|
if graph_runtime_state.graph_execution_json != "":
|
||||||
|
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
|
||||||
|
|
||||||
# === Core Dependencies ===
|
# === Core Dependencies ===
|
||||||
# Graph structure and configuration
|
# Graph structure and configuration
|
||||||
self._graph = graph
|
self._graph = graph
|
||||||
self._graph_config = graph_config
|
|
||||||
self._graph_runtime_state = graph_runtime_state
|
self._graph_runtime_state = graph_runtime_state
|
||||||
self._command_channel = command_channel
|
self._command_channel = command_channel
|
||||||
|
|
||||||
@ -107,20 +85,28 @@ class GraphEngine:
|
|||||||
self._scale_down_idle_time = scale_down_idle_time
|
self._scale_down_idle_time = scale_down_idle_time
|
||||||
|
|
||||||
# === Execution Queues ===
|
# === Execution Queues ===
|
||||||
# Queue for nodes ready to execute
|
# Create ready queue from saved state or initialize new one
|
||||||
self._ready_queue: queue.Queue[str] = queue.Queue()
|
self._ready_queue: ReadyQueue
|
||||||
|
if self._graph_runtime_state.ready_queue_json == "":
|
||||||
|
self._ready_queue = InMemoryReadyQueue()
|
||||||
|
else:
|
||||||
|
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
|
||||||
|
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
|
||||||
|
|
||||||
# Queue for events generated during execution
|
# Queue for events generated during execution
|
||||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||||
|
|
||||||
# === State Management ===
|
# === State Management ===
|
||||||
# Unified state manager handles all node state transitions and queue operations
|
# Unified state manager handles all node state transitions and queue operations
|
||||||
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
|
self._state_manager = GraphStateManager(self._graph, self._ready_queue)
|
||||||
|
|
||||||
# === Response Coordination ===
|
# === Response Coordination ===
|
||||||
# Coordinates response streaming from response nodes
|
# Coordinates response streaming from response nodes
|
||||||
self._response_coordinator = ResponseStreamCoordinator(
|
self._response_coordinator = ResponseStreamCoordinator(
|
||||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||||
)
|
)
|
||||||
|
if graph_runtime_state.response_coordinator_json != "":
|
||||||
|
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
|
||||||
|
|
||||||
# === Event Management ===
|
# === Event Management ===
|
||||||
# Event manager handles both collection and emission of events
|
# Event manager handles both collection and emission of events
|
||||||
@ -216,7 +202,6 @@ class GraphEngine:
|
|||||||
event_handler=self._event_handler_registry,
|
event_handler=self._event_handler_registry,
|
||||||
event_collector=self._event_manager,
|
event_collector=self._event_manager,
|
||||||
execution_coordinator=self._execution_coordinator,
|
execution_coordinator=self._execution_coordinator,
|
||||||
max_execution_time=self._execution_context.max_execution_time,
|
|
||||||
event_emitter=self._event_manager,
|
event_emitter=self._event_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Unified state manager that combines node, edge, and execution tracking.
|
Graph state manager that combines node, edge, and execution tracking.
|
||||||
|
|
||||||
This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
|
|
||||||
and ExecutionTracker into a single cohesive class.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import queue
|
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TypedDict, final
|
from typing import TypedDict, final
|
||||||
@ -13,6 +9,8 @@ from typing import TypedDict, final
|
|||||||
from core.workflow.enums import NodeState
|
from core.workflow.enums import NodeState
|
||||||
from core.workflow.graph import Edge, Graph
|
from core.workflow.graph import Edge, Graph
|
||||||
|
|
||||||
|
from .ready_queue import ReadyQueue
|
||||||
|
|
||||||
|
|
||||||
class EdgeStateAnalysis(TypedDict):
|
class EdgeStateAnalysis(TypedDict):
|
||||||
"""Analysis result for edge states."""
|
"""Analysis result for edge states."""
|
||||||
@ -23,24 +21,10 @@ class EdgeStateAnalysis(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class UnifiedStateManager:
|
class GraphStateManager:
|
||||||
"""
|
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
|
||||||
Unified manager for all graph state operations.
|
|
||||||
|
|
||||||
This class combines the responsibilities of:
|
|
||||||
- NodeStateManager: Node state transitions and ready queue
|
|
||||||
- EdgeStateManager: Edge state transitions and analysis
|
|
||||||
- ExecutionTracker: Tracking executing nodes
|
|
||||||
|
|
||||||
Benefits:
|
|
||||||
- Single lock for all state operations (reduced contention)
|
|
||||||
- Cohesive state management interface
|
|
||||||
- Simplified dependency injection
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
|
|
||||||
"""
|
"""
|
||||||
Initialize the unified state manager.
|
Initialize the state manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: The workflow graph
|
graph: The workflow graph
|
||||||
@ -9,8 +9,8 @@ from core.workflow.enums import NodeExecutionType
|
|||||||
from core.workflow.graph import Edge, Graph
|
from core.workflow.graph import Edge, Graph
|
||||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||||
|
|
||||||
|
from ..graph_state_manager import GraphStateManager
|
||||||
from ..response_coordinator import ResponseStreamCoordinator
|
from ..response_coordinator import ResponseStreamCoordinator
|
||||||
from ..state_management import UnifiedStateManager
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .skip_propagator import SkipPropagator
|
from .skip_propagator import SkipPropagator
|
||||||
@ -29,7 +29,7 @@ class EdgeProcessor:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
state_manager: UnifiedStateManager,
|
state_manager: GraphStateManager,
|
||||||
response_coordinator: ResponseStreamCoordinator,
|
response_coordinator: ResponseStreamCoordinator,
|
||||||
skip_propagator: "SkipPropagator",
|
skip_propagator: "SkipPropagator",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import final
|
|||||||
|
|
||||||
from core.workflow.graph import Edge, Graph
|
from core.workflow.graph import Edge, Graph
|
||||||
|
|
||||||
from ..state_management import UnifiedStateManager
|
from ..graph_state_manager import GraphStateManager
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@ -22,7 +22,7 @@ class SkipPropagator:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
state_manager: UnifiedStateManager,
|
state_manager: GraphStateManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the skip propagator.
|
Initialize the skip propagator.
|
||||||
|
|||||||
@ -34,7 +34,6 @@ class Dispatcher:
|
|||||||
event_handler: "EventHandler",
|
event_handler: "EventHandler",
|
||||||
event_collector: EventManager,
|
event_collector: EventManager,
|
||||||
execution_coordinator: ExecutionCoordinator,
|
execution_coordinator: ExecutionCoordinator,
|
||||||
max_execution_time: int,
|
|
||||||
event_emitter: EventManager | None = None,
|
event_emitter: EventManager | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -45,14 +44,12 @@ class Dispatcher:
|
|||||||
event_handler: Event handler registry for processing events
|
event_handler: Event handler registry for processing events
|
||||||
event_collector: Event manager for collecting unhandled events
|
event_collector: Event manager for collecting unhandled events
|
||||||
execution_coordinator: Coordinator for execution flow
|
execution_coordinator: Coordinator for execution flow
|
||||||
max_execution_time: Maximum execution time in seconds
|
|
||||||
event_emitter: Optional event manager to signal completion
|
event_emitter: Optional event manager to signal completion
|
||||||
"""
|
"""
|
||||||
self._event_queue = event_queue
|
self._event_queue = event_queue
|
||||||
self._event_handler = event_handler
|
self._event_handler = event_handler
|
||||||
self._event_collector = event_collector
|
self._event_collector = event_collector
|
||||||
self._execution_coordinator = execution_coordinator
|
self._execution_coordinator = execution_coordinator
|
||||||
self._max_execution_time = max_execution_time
|
|
||||||
self._event_emitter = event_emitter
|
self._event_emitter = event_emitter
|
||||||
|
|
||||||
self._thread: threading.Thread | None = None
|
self._thread: threading.Thread | None = None
|
||||||
@ -89,7 +86,7 @@ class Dispatcher:
|
|||||||
try:
|
try:
|
||||||
event = self._event_queue.get(timeout=0.1)
|
event = self._event_queue.get(timeout=0.1)
|
||||||
# Route to the event handler
|
# Route to the event handler
|
||||||
self._event_handler.handle_event(event)
|
self._event_handler.dispatch(event)
|
||||||
self._event_queue.task_done()
|
self._event_queue.task_done()
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
# Check if execution is complete
|
# Check if execution is complete
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, final
|
|||||||
from ..command_processing import CommandProcessor
|
from ..command_processing import CommandProcessor
|
||||||
from ..domain import GraphExecution
|
from ..domain import GraphExecution
|
||||||
from ..event_management import EventManager
|
from ..event_management import EventManager
|
||||||
from ..state_management import UnifiedStateManager
|
from ..graph_state_manager import GraphStateManager
|
||||||
from ..worker_management import WorkerPool
|
from ..worker_management import WorkerPool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -26,7 +26,7 @@ class ExecutionCoordinator:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph_execution: GraphExecution,
|
graph_execution: GraphExecution,
|
||||||
state_manager: UnifiedStateManager,
|
state_manager: GraphStateManager,
|
||||||
event_handler: "EventHandler",
|
event_handler: "EventHandler",
|
||||||
event_collector: EventManager,
|
event_collector: EventManager,
|
||||||
command_processor: CommandProcessor,
|
command_processor: CommandProcessor,
|
||||||
|
|||||||
@ -1,31 +0,0 @@
|
|||||||
"""
|
|
||||||
Base error strategy protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorStrategy(Protocol):
|
|
||||||
"""
|
|
||||||
Protocol for error handling strategies.
|
|
||||||
|
|
||||||
Each strategy implements a different approach to handling
|
|
||||||
node execution failures.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def handle_error(self, event: NodeRunFailedEvent, graph: Graph, retry_count: int) -> GraphNodeEventBase | None:
|
|
||||||
"""
|
|
||||||
Handle a node failure event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: The failure event
|
|
||||||
graph: The workflow graph
|
|
||||||
retry_count: Current retry attempt count
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional new event to process, or None to stop
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
12
api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
12
api/core/workflow/graph_engine/ready_queue/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Ready queue implementations for GraphEngine.
|
||||||
|
|
||||||
|
This package contains the protocol and implementations for managing
|
||||||
|
the queue of nodes ready for execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .factory import create_ready_queue_from_state
|
||||||
|
from .in_memory import InMemoryReadyQueue
|
||||||
|
from .protocol import ReadyQueue, ReadyQueueState
|
||||||
|
|
||||||
|
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]
|
||||||
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
35
api/core/workflow/graph_engine/ready_queue/factory.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Factory for creating ReadyQueue instances from serialized state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from .in_memory import InMemoryReadyQueue
|
||||||
|
from .protocol import ReadyQueueState
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .protocol import ReadyQueue
|
||||||
|
|
||||||
|
|
||||||
|
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
|
||||||
|
"""
|
||||||
|
Create a ReadyQueue instance from a serialized state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ReadyQueue instance initialized with the given state
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the queue type is unknown or version is unsupported
|
||||||
|
"""
|
||||||
|
if state.type == "InMemoryReadyQueue":
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
|
||||||
|
queue = InMemoryReadyQueue()
|
||||||
|
# Always pass as JSON string to loads()
|
||||||
|
queue.loads(state.model_dump_json())
|
||||||
|
return queue
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown ready queue type: {state.type}")
|
||||||
140
api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
140
api/core/workflow/graph_engine/ready_queue/in_memory.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
In-memory implementation of the ReadyQueue protocol.
|
||||||
|
|
||||||
|
This implementation wraps Python's standard queue.Queue and adds
|
||||||
|
serialization capabilities for state storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import queue
|
||||||
|
from typing import final
|
||||||
|
|
||||||
|
from .protocol import ReadyQueue, ReadyQueueState
|
||||||
|
|
||||||
|
|
||||||
|
@final
|
||||||
|
class InMemoryReadyQueue(ReadyQueue):
|
||||||
|
"""
|
||||||
|
In-memory ready queue implementation with serialization support.
|
||||||
|
|
||||||
|
This implementation uses Python's queue.Queue internally and provides
|
||||||
|
methods to serialize and restore the queue state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int = 0) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the in-memory ready queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maxsize: Maximum size of the queue (0 for unlimited)
|
||||||
|
"""
|
||||||
|
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
|
||||||
|
|
||||||
|
def put(self, item: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a node ID to the ready queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: The node ID to add to the queue
|
||||||
|
"""
|
||||||
|
self._queue.put(item)
|
||||||
|
|
||||||
|
def get(self, timeout: float | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve and remove a node ID from the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait for an item (None for blocking)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The node ID retrieved from the queue
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
queue.Empty: If timeout expires and no item is available
|
||||||
|
"""
|
||||||
|
if timeout is None:
|
||||||
|
return self._queue.get(block=True)
|
||||||
|
return self._queue.get(timeout=timeout)
|
||||||
|
|
||||||
|
def task_done(self) -> None:
|
||||||
|
"""
|
||||||
|
Indicate that a previously retrieved task is complete.
|
||||||
|
|
||||||
|
Used by worker threads to signal task completion for
|
||||||
|
join() synchronization.
|
||||||
|
"""
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the queue is empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the queue has no items, False otherwise
|
||||||
|
"""
|
||||||
|
return self._queue.empty()
|
||||||
|
|
||||||
|
def qsize(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the approximate size of the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The approximate number of items in the queue
|
||||||
|
"""
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
"""
|
||||||
|
Serialize the queue state to a JSON string for storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON string containing the serialized queue state
|
||||||
|
"""
|
||||||
|
# Extract all items from the queue without removing them
|
||||||
|
items: list[str] = []
|
||||||
|
temp_items: list[str] = []
|
||||||
|
|
||||||
|
# Drain the queue temporarily to get all items
|
||||||
|
while not self._queue.empty():
|
||||||
|
try:
|
||||||
|
item = self._queue.get_nowait()
|
||||||
|
temp_items.append(item)
|
||||||
|
items.append(item)
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Put items back in the same order
|
||||||
|
for item in temp_items:
|
||||||
|
self._queue.put(item)
|
||||||
|
|
||||||
|
state = ReadyQueueState(
|
||||||
|
type="InMemoryReadyQueue",
|
||||||
|
version="1.0",
|
||||||
|
items=items,
|
||||||
|
)
|
||||||
|
return state.model_dump_json()
|
||||||
|
|
||||||
|
def loads(self, data: str) -> None:
|
||||||
|
"""
|
||||||
|
Restore the queue state from a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The JSON string containing the serialized queue state to restore
|
||||||
|
"""
|
||||||
|
state = ReadyQueueState.model_validate_json(data)
|
||||||
|
|
||||||
|
if state.type != "InMemoryReadyQueue":
|
||||||
|
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||||
|
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported version: {state.version}")
|
||||||
|
|
||||||
|
# Clear the current queue
|
||||||
|
while not self._queue.empty():
|
||||||
|
try:
|
||||||
|
self._queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Restore items
|
||||||
|
for item in state.items:
|
||||||
|
self._queue.put(item)
|
||||||
104
api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
104
api/core/workflow/graph_engine/ready_queue/protocol.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
ReadyQueue protocol for GraphEngine node execution queue.
|
||||||
|
|
||||||
|
This protocol defines the interface for managing the queue of nodes ready
|
||||||
|
for execution, supporting both in-memory and persistent storage scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ReadyQueueState(BaseModel):
|
||||||
|
"""
|
||||||
|
Pydantic model for serialized ready queue state.
|
||||||
|
|
||||||
|
This defines the structure of the data returned by dumps()
|
||||||
|
and expected by loads() for ready queue serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
|
||||||
|
version: str = Field(description="Serialization format version")
|
||||||
|
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadyQueue(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for managing nodes ready for execution in GraphEngine.
|
||||||
|
|
||||||
|
This protocol defines the interface that any ready queue implementation
|
||||||
|
must provide, enabling both in-memory queues and persistent queues
|
||||||
|
that can be serialized for state storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def put(self, item: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a node ID to the ready queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: The node ID to add to the queue
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(self, timeout: float | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Retrieve and remove a node ID from the queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait for an item (None for blocking)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The node ID retrieved from the queue
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
queue.Empty: If timeout expires and no item is available
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def task_done(self) -> None:
|
||||||
|
"""
|
||||||
|
Indicate that a previously retrieved task is complete.
|
||||||
|
|
||||||
|
Used by worker threads to signal task completion for
|
||||||
|
join() synchronization.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the queue is empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the queue has no items, False otherwise
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def qsize(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the approximate size of the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The approximate number of items in the queue
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
"""
|
||||||
|
Serialize the queue state to a JSON string for storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON string containing the serialized queue state
|
||||||
|
that can be persisted and later restored
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def loads(self, data: str) -> None:
|
||||||
|
"""
|
||||||
|
Restore the queue state from a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The JSON string containing the serialized queue state to restore
|
||||||
|
"""
|
||||||
|
...
|
||||||
@ -9,9 +9,11 @@ import logging
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import TypeAlias, final
|
from typing import Literal, TypeAlias, final
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import NodeExecutionType, NodeState
|
from core.workflow.enums import NodeExecutionType, NodeState
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@ -28,6 +30,43 @@ NodeID: TypeAlias = str
|
|||||||
EdgeID: TypeAlias = str
|
EdgeID: TypeAlias = str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseSessionState(BaseModel):
|
||||||
|
"""Serializable representation of a response session."""
|
||||||
|
|
||||||
|
node_id: str
|
||||||
|
index: int = Field(default=0, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamBufferState(BaseModel):
|
||||||
|
"""Serializable representation of buffered stream chunks."""
|
||||||
|
|
||||||
|
selector: tuple[str, ...]
|
||||||
|
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamPositionState(BaseModel):
|
||||||
|
"""Serializable representation for stream read positions."""
|
||||||
|
|
||||||
|
selector: tuple[str, ...]
|
||||||
|
position: int = Field(default=0, ge=0)
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseStreamCoordinatorState(BaseModel):
|
||||||
|
"""Serialized snapshot of ResponseStreamCoordinator."""
|
||||||
|
|
||||||
|
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
|
||||||
|
version: str = Field(default="1.0")
|
||||||
|
response_nodes: Sequence[str] = Field(default_factory=list)
|
||||||
|
active_session: ResponseSessionState | None = None
|
||||||
|
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||||
|
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
|
||||||
|
node_execution_ids: dict[str, str] = Field(default_factory=dict)
|
||||||
|
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
|
||||||
|
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
|
||||||
|
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
|
||||||
|
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class ResponseStreamCoordinator:
|
class ResponseStreamCoordinator:
|
||||||
"""
|
"""
|
||||||
@ -69,6 +108,8 @@ class ResponseStreamCoordinator:
|
|||||||
|
|
||||||
def register(self, response_node_id: NodeID) -> None:
|
def register(self, response_node_id: NodeID) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
if response_node_id in self._response_nodes:
|
||||||
|
return
|
||||||
self._response_nodes.add(response_node_id)
|
self._response_nodes.add(response_node_id)
|
||||||
|
|
||||||
# Build and save paths map for this response node
|
# Build and save paths map for this response node
|
||||||
@ -558,3 +599,98 @@ class ResponseStreamCoordinator:
|
|||||||
"""
|
"""
|
||||||
key = tuple(selector)
|
key = tuple(selector)
|
||||||
return key in self._closed_streams
|
return key in self._closed_streams
|
||||||
|
|
||||||
|
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
|
||||||
|
"""Convert an in-memory session into its serializable form."""
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
return None
|
||||||
|
return ResponseSessionState(node_id=session.node_id, index=session.index)
|
||||||
|
|
||||||
|
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
|
||||||
|
"""Rebuild a response session from serialized data."""
|
||||||
|
|
||||||
|
node = self._graph.nodes.get(session_state.node_id)
|
||||||
|
if node is None:
|
||||||
|
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
|
||||||
|
|
||||||
|
session = ResponseSession.from_node(node)
|
||||||
|
session.index = session_state.index
|
||||||
|
return session
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
"""Serialize coordinator state to JSON."""
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
state = ResponseStreamCoordinatorState(
|
||||||
|
response_nodes=sorted(self._response_nodes),
|
||||||
|
active_session=self._serialize_session(self._active_session),
|
||||||
|
waiting_sessions=[
|
||||||
|
session_state
|
||||||
|
for session in list(self._waiting_sessions)
|
||||||
|
if (session_state := self._serialize_session(session)) is not None
|
||||||
|
],
|
||||||
|
pending_sessions=[
|
||||||
|
session_state
|
||||||
|
for _, session in sorted(self._response_sessions.items())
|
||||||
|
if (session_state := self._serialize_session(session)) is not None
|
||||||
|
],
|
||||||
|
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
|
||||||
|
paths_map={
|
||||||
|
node_id: [path.edges.copy() for path in paths]
|
||||||
|
for node_id, paths in sorted(self._paths_maps.items())
|
||||||
|
},
|
||||||
|
stream_buffers=[
|
||||||
|
StreamBufferState(
|
||||||
|
selector=selector,
|
||||||
|
events=[event.model_copy(deep=True) for event in events],
|
||||||
|
)
|
||||||
|
for selector, events in sorted(self._stream_buffers.items())
|
||||||
|
],
|
||||||
|
stream_positions=[
|
||||||
|
StreamPositionState(selector=selector, position=position)
|
||||||
|
for selector, position in sorted(self._stream_positions.items())
|
||||||
|
],
|
||||||
|
closed_streams=sorted(self._closed_streams),
|
||||||
|
)
|
||||||
|
return state.model_dump_json()
|
||||||
|
|
||||||
|
def loads(self, data: str) -> None:
|
||||||
|
"""Restore coordinator state from JSON."""
|
||||||
|
|
||||||
|
state = ResponseStreamCoordinatorState.model_validate_json(data)
|
||||||
|
|
||||||
|
if state.type != "ResponseStreamCoordinator":
|
||||||
|
raise ValueError(f"Invalid serialized data type: {state.type}")
|
||||||
|
|
||||||
|
if state.version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported serialized version: {state.version}")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._response_nodes = set(state.response_nodes)
|
||||||
|
self._paths_maps = {
|
||||||
|
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
|
||||||
|
for node_id, paths in state.paths_map.items()
|
||||||
|
}
|
||||||
|
self._node_execution_ids = dict(state.node_execution_ids)
|
||||||
|
|
||||||
|
self._stream_buffers = {
|
||||||
|
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
|
||||||
|
for buffer in state.stream_buffers
|
||||||
|
}
|
||||||
|
self._stream_positions = {
|
||||||
|
tuple(position.selector): position.position for position in state.stream_positions
|
||||||
|
}
|
||||||
|
for selector in self._stream_buffers:
|
||||||
|
self._stream_positions.setdefault(selector, 0)
|
||||||
|
|
||||||
|
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
|
||||||
|
|
||||||
|
self._waiting_sessions = deque(
|
||||||
|
self._session_from_state(session_state) for session_state in state.waiting_sessions
|
||||||
|
)
|
||||||
|
self._response_sessions = {
|
||||||
|
session_state.node_id: self._session_from_state(session_state)
|
||||||
|
for session_state in state.pending_sessions
|
||||||
|
}
|
||||||
|
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class Path:
|
|||||||
Note: This is an internal class not exposed in the public API.
|
Note: This is an internal class not exposed in the public API.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
edges: list[EdgeID] = field(default_factory=list)
|
edges: list[EdgeID] = field(default_factory=list[EdgeID])
|
||||||
|
|
||||||
def contains_edge(self, edge_id: EdgeID) -> bool:
|
def contains_edge(self, edge_id: EdgeID) -> bool:
|
||||||
"""Check if this path contains the given edge."""
|
"""Check if this path contains the given edge."""
|
||||||
|
|||||||
@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
State management subsystem for graph engine.
|
|
||||||
|
|
||||||
This package manages node states, edge states, and execution tracking
|
|
||||||
during workflow graph execution.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .unified_state_manager import UnifiedStateManager
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"UnifiedStateManager",
|
|
||||||
]
|
|
||||||
@ -22,6 +22,8 @@ from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
|||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
|
||||||
|
from .ready_queue import ReadyQueue
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class Worker(threading.Thread):
|
class Worker(threading.Thread):
|
||||||
@ -35,7 +37,7 @@ class Worker(threading.Thread):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ready_queue: queue.Queue[str],
|
ready_queue: ReadyQueue,
|
||||||
event_queue: queue.Queue[GraphNodeEventBase],
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
worker_id: int = 0,
|
worker_id: int = 0,
|
||||||
@ -46,7 +48,7 @@ class Worker(threading.Thread):
|
|||||||
Initialize worker thread.
|
Initialize worker thread.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ready_queue: Queue containing node IDs ready for execution
|
ready_queue: Ready queue containing node IDs ready for execution
|
||||||
event_queue: Queue for pushing execution events
|
event_queue: Queue for pushing execution events
|
||||||
graph: Graph containing nodes to execute
|
graph: Graph containing nodes to execute
|
||||||
worker_id: Unique identifier for this worker
|
worker_id: Unique identifier for this worker
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from configs import dify_config
|
|||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_events import GraphNodeEventBase
|
from core.workflow.graph_events import GraphNodeEventBase
|
||||||
|
|
||||||
|
from ..ready_queue import ReadyQueue
|
||||||
from ..worker import Worker
|
from ..worker import Worker
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -35,7 +36,7 @@ class WorkerPool:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ready_queue: queue.Queue[str],
|
ready_queue: ReadyQueue,
|
||||||
event_queue: queue.Queue[GraphNodeEventBase],
|
event_queue: queue.Queue[GraphNodeEventBase],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
flask_app: "Flask | None" = None,
|
flask_app: "Flask | None" = None,
|
||||||
@ -49,7 +50,7 @@ class WorkerPool:
|
|||||||
Initialize the simple worker pool.
|
Initialize the simple worker pool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ready_queue: Queue of nodes ready for execution
|
ready_queue: Ready queue for nodes ready for execution
|
||||||
event_queue: Queue for worker events
|
event_queue: Queue for worker events
|
||||||
graph: The workflow graph
|
graph: The workflow graph
|
||||||
flask_app: Optional Flask app for context preservation
|
flask_app: Optional Flask app for context preservation
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -14,4 +14,4 @@ class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
|
|||||||
error: str | None = Field(..., description="error")
|
error: str | None = Field(..., description="error")
|
||||||
status: str = Field(..., description="status")
|
status: str = Field(..., description="status")
|
||||||
data: Mapping[str, Any] = Field(..., description="data")
|
data: Mapping[str, Any] = Field(..., description="data")
|
||||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
@ -19,9 +17,9 @@ class GraphNodeEventBase(GraphEngineEvent):
|
|||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
|
|
||||||
in_iteration_id: Optional[str] = None
|
in_iteration_id: str | None = None
|
||||||
"""iteration id if node is in iteration"""
|
"""iteration id if node is in iteration"""
|
||||||
in_loop_id: Optional[str] = None
|
in_loop_id: str | None = None
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
|
|
||||||
# The version of the node, or "1" if not specified.
|
# The version of the node, or "1" if not specified.
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from core.workflow.graph_events import BaseGraphEvent
|
from core.workflow.graph_events import BaseGraphEvent
|
||||||
@ -10,7 +8,7 @@ class GraphRunStartedEvent(BaseGraphEvent):
|
|||||||
|
|
||||||
|
|
||||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||||
outputs: Optional[dict[str, Any]] = None
|
outputs: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class GraphRunFailedEvent(BaseGraphEvent):
|
class GraphRunFailedEvent(BaseGraphEvent):
|
||||||
@ -20,11 +18,11 @@ class GraphRunFailedEvent(BaseGraphEvent):
|
|||||||
|
|
||||||
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
class GraphRunPartialSucceededEvent(BaseGraphEvent):
|
||||||
exceptions_count: int = Field(..., description="exception count")
|
exceptions_count: int = Field(..., description="exception count")
|
||||||
outputs: Optional[dict[str, Any]] = None
|
outputs: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class GraphRunAbortedEvent(BaseGraphEvent):
|
class GraphRunAbortedEvent(BaseGraphEvent):
|
||||||
"""Event emitted when a graph run is aborted by user command."""
|
"""Event emitted when a graph run is aborted by user command."""
|
||||||
|
|
||||||
reason: Optional[str] = Field(default=None, description="reason for abort")
|
reason: str | None = Field(default=None, description="reason for abort")
|
||||||
outputs: Optional[dict[str, Any]] = Field(default=None, description="partial outputs if any")
|
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
|
|||||||
class NodeRunIterationStartedEvent(GraphNodeEventBase):
|
class NodeRunIterationStartedEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunIterationNextEvent(GraphNodeEventBase):
|
class NodeRunIterationNextEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
index: int = Field(..., description="index")
|
index: int = Field(..., description="index")
|
||||||
pre_iteration_output: Optional[Any] = None
|
pre_iteration_output: Any = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
|
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class NodeRunIterationFailedEvent(GraphNodeEventBase):
|
class NodeRunIterationFailedEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
error: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
|
|||||||
class NodeRunLoopStartedEvent(GraphNodeEventBase):
|
class NodeRunLoopStartedEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunLoopNextEvent(GraphNodeEventBase):
|
class NodeRunLoopNextEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
index: int = Field(..., description="index")
|
index: int = Field(..., description="index")
|
||||||
pre_loop_output: Optional[Any] = None
|
pre_loop_output: Any = None
|
||||||
|
|
||||||
|
|
||||||
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
|
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class NodeRunLoopFailedEvent(GraphNodeEventBase):
|
class NodeRunLoopFailedEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
error: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -12,9 +11,8 @@ from .base import GraphNodeEventBase
|
|||||||
|
|
||||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||||
node_title: str
|
node_title: str
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
parallel_mode_run_id: Optional[str] = None
|
agent_strategy: AgentNodeStrategyInit | None = None
|
||||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
|
||||||
start_at: datetime = Field(..., description="node start time")
|
start_at: datetime = Field(..., description="node start time")
|
||||||
|
|
||||||
# FIXME(-LAN-): only for ToolNode
|
# FIXME(-LAN-): only for ToolNode
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -14,5 +14,5 @@ class AgentLogEvent(NodeEventBase):
|
|||||||
error: str | None = Field(..., description="error")
|
error: str | None = Field(..., description="error")
|
||||||
status: str = Field(..., description="status")
|
status: str = Field(..., description="status")
|
||||||
data: Mapping[str, Any] = Field(..., description="data")
|
data: Mapping[str, Any] = Field(..., description="data")
|
||||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata")
|
||||||
node_id: str = Field(..., description="node id")
|
node_id: str = Field(..., description="node id")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -9,28 +9,28 @@ from .base import NodeEventBase
|
|||||||
|
|
||||||
class IterationStartedEvent(NodeEventBase):
|
class IterationStartedEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class IterationNextEvent(NodeEventBase):
|
class IterationNextEvent(NodeEventBase):
|
||||||
index: int = Field(..., description="index")
|
index: int = Field(..., description="index")
|
||||||
pre_iteration_output: Optional[Any] = None
|
pre_iteration_output: Any = None
|
||||||
|
|
||||||
|
|
||||||
class IterationSucceededEvent(NodeEventBase):
|
class IterationSucceededEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class IterationFailedEvent(NodeEventBase):
|
class IterationFailedEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
error: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -9,28 +9,28 @@ from .base import NodeEventBase
|
|||||||
|
|
||||||
class LoopStartedEvent(NodeEventBase):
|
class LoopStartedEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class LoopNextEvent(NodeEventBase):
|
class LoopNextEvent(NodeEventBase):
|
||||||
index: int = Field(..., description="index")
|
index: int = Field(..., description="index")
|
||||||
pre_loop_output: Optional[Any] = None
|
pre_loop_output: Any = None
|
||||||
|
|
||||||
|
|
||||||
class LoopSucceededEvent(NodeEventBase):
|
class LoopSucceededEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class LoopFailedEvent(NodeEventBase):
|
class LoopFailedEvent(NodeEventBase):
|
||||||
start_at: datetime = Field(..., description="start at")
|
start_at: datetime = Field(..., description="start at")
|
||||||
inputs: Optional[Mapping[str, Any]] = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Optional[Mapping[str, Any]] = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Optional[Mapping[str, Any]] = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
error: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|||||||
@ -33,7 +33,13 @@ from core.workflow.enums import (
|
|||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.node_events import AgentLogEvent, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
from core.workflow.node_events import (
|
||||||
|
AgentLogEvent,
|
||||||
|
NodeEventBase,
|
||||||
|
NodeRunResult,
|
||||||
|
StreamChunkEvent,
|
||||||
|
StreamCompletedEvent,
|
||||||
|
)
|
||||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
@ -93,7 +99,7 @@ class AgentNode(Node):
|
|||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -482,7 +488,7 @@ class AgentNode(Node):
|
|||||||
node_type: NodeType,
|
node_type: NodeType,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
node_execution_id: str,
|
node_execution_id: str,
|
||||||
) -> Generator:
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
from functools import singledispatchmethod
|
||||||
|
from typing import Any, ClassVar
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphNodeEventBase,
|
GraphNodeEventBase,
|
||||||
NodeRunAgentLogEvent,
|
NodeRunAgentLogEvent,
|
||||||
@ -45,11 +46,6 @@ from models.enums import UserFrom
|
|||||||
|
|
||||||
from .entities import BaseNodeData, RetryConfig
|
from .entities import BaseNodeData, RetryConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -88,14 +84,14 @@ class Node:
|
|||||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self) -> "NodeRunResult | Generator[GraphNodeEventBase, None, None]":
|
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def run(self) -> "Generator[GraphNodeEventBase, None, None]":
|
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||||
# Generate a single node execution ID to use for all events
|
# Generate a single node execution ID to use for all events
|
||||||
if not self._node_execution_id:
|
if not self._node_execution_id:
|
||||||
self._node_execution_id = str(uuid4())
|
self._node_execution_id = str(uuid4())
|
||||||
@ -151,12 +147,14 @@ class Node:
|
|||||||
|
|
||||||
# Handle event stream
|
# Handle event stream
|
||||||
for event in result:
|
for event in result:
|
||||||
if isinstance(event, NodeEventBase):
|
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
|
||||||
event = self._convert_node_event_to_graph_node_event(event)
|
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||||
|
yield self._dispatch(event)
|
||||||
if not event.in_iteration_id and not event.in_loop_id:
|
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||||
event.id = self._node_execution_id
|
event.id = self._node_execution_id
|
||||||
yield event
|
yield event
|
||||||
|
else:
|
||||||
|
yield event
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Node %s failed to run", self._node_id)
|
logger.exception("Node %s failed to run", self._node_id)
|
||||||
result = NodeRunResult(
|
result = NodeRunResult(
|
||||||
@ -249,7 +247,7 @@ class Node:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -270,7 +268,7 @@ class Node:
|
|||||||
# to BaseNodeData properties in a type-safe way
|
# to BaseNodeData properties in a type-safe way
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_error_strategy(self) -> Optional["ErrorStrategy"]:
|
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||||
"""Get the error strategy for this node."""
|
"""Get the error strategy for this node."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -301,7 +299,7 @@ class Node:
|
|||||||
|
|
||||||
# Public interface properties that delegate to abstract methods
|
# Public interface properties that delegate to abstract methods
|
||||||
@property
|
@property
|
||||||
def error_strategy(self) -> Optional["ErrorStrategy"]:
|
def error_strategy(self) -> ErrorStrategy | None:
|
||||||
"""Get the error strategy for this node."""
|
"""Get the error strategy for this node."""
|
||||||
return self._get_error_strategy()
|
return self._get_error_strategy()
|
||||||
|
|
||||||
@ -344,29 +342,15 @@ class Node:
|
|||||||
start_at=self._start_at,
|
start_at=self._start_at,
|
||||||
node_run_result=result,
|
node_run_result=result,
|
||||||
)
|
)
|
||||||
raise Exception(f"result status {result.status} not supported")
|
case _:
|
||||||
|
raise Exception(f"result status {result.status} not supported")
|
||||||
|
|
||||||
def _convert_node_event_to_graph_node_event(self, event: NodeEventBase) -> GraphNodeEventBase:
|
@singledispatchmethod
|
||||||
handler_maps: dict[type[NodeEventBase], Callable[[Any], GraphNodeEventBase]] = {
|
def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase:
|
||||||
StreamChunkEvent: self._handle_stream_chunk_event,
|
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
||||||
StreamCompletedEvent: self._handle_stream_completed_event,
|
|
||||||
AgentLogEvent: self._handle_agent_log_event,
|
|
||||||
LoopStartedEvent: self._handle_loop_started_event,
|
|
||||||
LoopNextEvent: self._handle_loop_next_event,
|
|
||||||
LoopSucceededEvent: self._handle_loop_succeeded_event,
|
|
||||||
LoopFailedEvent: self._handle_loop_failed_event,
|
|
||||||
IterationStartedEvent: self._handle_iteration_started_event,
|
|
||||||
IterationNextEvent: self._handle_iteration_next_event,
|
|
||||||
IterationSucceededEvent: self._handle_iteration_succeeded_event,
|
|
||||||
IterationFailedEvent: self._handle_iteration_failed_event,
|
|
||||||
RunRetrieverResourceEvent: self._handle_run_retriever_resource_event,
|
|
||||||
}
|
|
||||||
handler = handler_maps.get(type(event))
|
|
||||||
if not handler:
|
|
||||||
raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}")
|
|
||||||
return handler(event)
|
|
||||||
|
|
||||||
def _handle_stream_chunk_event(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||||
return NodeRunStreamChunkEvent(
|
return NodeRunStreamChunkEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -376,7 +360,8 @@ class Node:
|
|||||||
is_final=event.is_final,
|
is_final=event.is_final,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_stream_completed_event(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||||
match event.node_run_result.status:
|
match event.node_run_result.status:
|
||||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
return NodeRunSucceededEvent(
|
return NodeRunSucceededEvent(
|
||||||
@ -395,9 +380,13 @@ class Node:
|
|||||||
node_run_result=event.node_run_result,
|
node_run_result=event.node_run_result,
|
||||||
error=event.node_run_result.error,
|
error=event.node_run_result.error,
|
||||||
)
|
)
|
||||||
raise NotImplementedError(f"Node {self._node_id} does not support status {event.node_run_result.status}")
|
case _:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Node {self._node_id} does not support status {event.node_run_result.status}"
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_agent_log_event(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
|
||||||
return NodeRunAgentLogEvent(
|
return NodeRunAgentLogEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -412,7 +401,8 @@ class Node:
|
|||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_loop_started_event(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||||
return NodeRunLoopStartedEvent(
|
return NodeRunLoopStartedEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -424,7 +414,8 @@ class Node:
|
|||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_loop_next_event(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
|
||||||
return NodeRunLoopNextEvent(
|
return NodeRunLoopNextEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -434,7 +425,8 @@ class Node:
|
|||||||
pre_loop_output=event.pre_loop_output,
|
pre_loop_output=event.pre_loop_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_loop_succeeded_event(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
|
||||||
return NodeRunLoopSucceededEvent(
|
return NodeRunLoopSucceededEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -447,7 +439,8 @@ class Node:
|
|||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_loop_failed_event(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
|
||||||
return NodeRunLoopFailedEvent(
|
return NodeRunLoopFailedEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -461,7 +454,8 @@ class Node:
|
|||||||
error=event.error,
|
error=event.error,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_started_event(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
|
||||||
return NodeRunIterationStartedEvent(
|
return NodeRunIterationStartedEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -473,7 +467,8 @@ class Node:
|
|||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_next_event(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
|
||||||
return NodeRunIterationNextEvent(
|
return NodeRunIterationNextEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -483,7 +478,8 @@ class Node:
|
|||||||
pre_iteration_output=event.pre_iteration_output,
|
pre_iteration_output=event.pre_iteration_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_succeeded_event(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
|
||||||
return NodeRunIterationSucceededEvent(
|
return NodeRunIterationSucceededEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -496,7 +492,8 @@ class Node:
|
|||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_iteration_failed_event(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
|
||||||
return NodeRunIterationFailedEvent(
|
return NodeRunIterationFailedEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
@ -510,7 +507,8 @@ class Node:
|
|||||||
error=event.error,
|
error=event.error,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_run_retriever_resource_event(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
@_dispatch.register
|
||||||
|
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||||
return NodeRunRetrieverResourceEvent(
|
return NodeRunRetrieverResourceEvent(
|
||||||
id=self._node_execution_id,
|
id=self._node_execution_id,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||||
@ -49,7 +49,7 @@ class CodeNode(Node):
|
|||||||
return self._node_data
|
return self._node_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
Get default config of node.
|
Get default config of node.
|
||||||
:param filters: filter by node config parameters.
|
:param filters: filter by node config parameters.
|
||||||
@ -57,7 +57,7 @@ class CodeNode(Node):
|
|||||||
"""
|
"""
|
||||||
code_language = CodeLanguage.PYTHON3
|
code_language = CodeLanguage.PYTHON3
|
||||||
if filters:
|
if filters:
|
||||||
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
|
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||||
|
|
||||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||||
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal, Self
|
||||||
|
|
||||||
from pydantic import AfterValidator, BaseModel
|
from pydantic import AfterValidator, BaseModel
|
||||||
|
|
||||||
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
|
|||||||
|
|
||||||
class Output(BaseModel):
|
class Output(BaseModel):
|
||||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
children: dict[str, Self] | None = None
|
||||||
|
|
||||||
class Dependency(BaseModel):
|
class Dependency(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|||||||
@ -58,7 +58,7 @@ class HttpRequestNode(Node):
|
|||||||
return self._node_data
|
return self._node_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict[str, Any] | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
"type": "http-request",
|
"type": "http-request",
|
||||||
"config": {
|
"config": {
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
outputs: list[Any] = Field(default_factory=list)
|
outputs: list[Any] = Field(default_factory=list)
|
||||||
current_output: Any | None = None
|
current_output: Any = None
|
||||||
|
|
||||||
class MetaData(BaseIterationState.MetaData):
|
class MetaData(BaseIterationState.MetaData):
|
||||||
"""
|
"""
|
||||||
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
|
|||||||
|
|
||||||
iterator_length: int
|
iterator_length: int
|
||||||
|
|
||||||
def get_last_output(self) -> Any | None:
|
def get_last_output(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Get last output.
|
Get last output.
|
||||||
"""
|
"""
|
||||||
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
|
|||||||
return self.outputs[-1]
|
return self.outputs[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_current_output(self) -> Any | None:
|
def get_current_output(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Get current output.
|
Get current output.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,7 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, Union, cast
|
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||||
|
|
||||||
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from core.variables import IntegerVariable, NoneSegment
|
from core.variables import IntegerVariable, NoneSegment
|
||||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||||
@ -23,6 +26,7 @@ from core.workflow.node_events import (
|
|||||||
IterationNextEvent,
|
IterationNextEvent,
|
||||||
IterationStartedEvent,
|
IterationStartedEvent,
|
||||||
IterationSucceededEvent,
|
IterationSucceededEvent,
|
||||||
|
NodeEventBase,
|
||||||
NodeRunResult,
|
NodeRunResult,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
@ -45,6 +49,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||||
|
|
||||||
|
|
||||||
class IterationNode(Node):
|
class IterationNode(Node):
|
||||||
"""
|
"""
|
||||||
@ -77,7 +83,7 @@ class IterationNode(Node):
|
|||||||
return self._node_data
|
return self._node_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
"type": "iteration",
|
"type": "iteration",
|
||||||
"config": {
|
"config": {
|
||||||
@ -91,44 +97,21 @@ class IterationNode(Node):
|
|||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
variable = self._get_iterator_variable()
|
||||||
|
|
||||||
if not variable:
|
if self._is_empty_iteration(variable):
|
||||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
yield from self._handle_empty_iteration(variable)
|
||||||
|
|
||||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
|
||||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
|
||||||
|
|
||||||
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
|
|
||||||
# Try our best to preserve the type informat.
|
|
||||||
if isinstance(variable, ArraySegment):
|
|
||||||
output = variable.model_copy(update={"value": []})
|
|
||||||
else:
|
|
||||||
output = ArrayAnySegment(value=[])
|
|
||||||
yield StreamCompletedEvent(
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
||||||
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
|
||||||
# from graph definition?
|
|
||||||
outputs={"output": output},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
iterator_list_value = variable.to_object()
|
iterator_list_value = self._validate_and_get_iterator_list(variable)
|
||||||
|
|
||||||
if not isinstance(iterator_list_value, list):
|
|
||||||
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
|
||||||
|
|
||||||
inputs = {"iterator_selector": iterator_list_value}
|
inputs = {"iterator_selector": iterator_list_value}
|
||||||
|
|
||||||
if not self._node_data.start_node_id:
|
self._validate_start_node()
|
||||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
|
||||||
|
|
||||||
started_at = naive_utc_now()
|
started_at = naive_utc_now()
|
||||||
iter_run_map: dict[str, float] = {}
|
iter_run_map: dict[str, float] = {}
|
||||||
outputs: list[Any] = []
|
outputs: list[object] = []
|
||||||
|
|
||||||
yield IterationStartedEvent(
|
yield IterationStartedEvent(
|
||||||
start_at=started_at,
|
start_at=started_at,
|
||||||
@ -137,6 +120,86 @@ class IterationNode(Node):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
yield from self._execute_iterations(
|
||||||
|
iterator_list_value=iterator_list_value,
|
||||||
|
outputs=outputs,
|
||||||
|
iter_run_map=iter_run_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield from self._handle_iteration_success(
|
||||||
|
started_at=started_at,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
iterator_list_value=iterator_list_value,
|
||||||
|
iter_run_map=iter_run_map,
|
||||||
|
)
|
||||||
|
except IterationNodeError as e:
|
||||||
|
yield from self._handle_iteration_failure(
|
||||||
|
started_at=started_at,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
iterator_list_value=iterator_list_value,
|
||||||
|
iter_run_map=iter_run_map,
|
||||||
|
error=e,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||||
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||||
|
|
||||||
|
if not variable:
|
||||||
|
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||||
|
|
||||||
|
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||||
|
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||||
|
|
||||||
|
return variable
|
||||||
|
|
||||||
|
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
|
||||||
|
return isinstance(variable, NoneSegment) or len(variable.value) == 0
|
||||||
|
|
||||||
|
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
|
||||||
|
# Try our best to preserve the type information.
|
||||||
|
if isinstance(variable, ArraySegment):
|
||||||
|
output = variable.model_copy(update={"value": []})
|
||||||
|
else:
|
||||||
|
output = ArrayAnySegment(value=[])
|
||||||
|
|
||||||
|
yield StreamCompletedEvent(
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
# TODO(QuantumGhost): is it possible to compute the type of `output`
|
||||||
|
# from graph definition?
|
||||||
|
outputs={"output": output},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
|
||||||
|
iterator_list_value = variable.to_object()
|
||||||
|
|
||||||
|
if not isinstance(iterator_list_value, list):
|
||||||
|
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||||
|
|
||||||
|
return cast(list[object], iterator_list_value)
|
||||||
|
|
||||||
|
def _validate_start_node(self) -> None:
|
||||||
|
if not self._node_data.start_node_id:
|
||||||
|
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||||
|
|
||||||
|
def _execute_iterations(
|
||||||
|
self,
|
||||||
|
iterator_list_value: Sequence[object],
|
||||||
|
outputs: list[object],
|
||||||
|
iter_run_map: dict[str, float],
|
||||||
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
|
if self._node_data.is_parallel:
|
||||||
|
# Parallel mode execution
|
||||||
|
yield from self._execute_parallel_iterations(
|
||||||
|
iterator_list_value=iterator_list_value,
|
||||||
|
outputs=outputs,
|
||||||
|
iter_run_map=iter_run_map,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Sequential mode execution
|
||||||
for index, item in enumerate(iterator_list_value):
|
for index, item in enumerate(iterator_list_value):
|
||||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
yield IterationNextEvent(index=index)
|
yield IterationNextEvent(index=index)
|
||||||
@ -154,45 +217,146 @@ class IterationNode(Node):
|
|||||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
yield IterationSucceededEvent(
|
def _execute_parallel_iterations(
|
||||||
start_at=started_at,
|
self,
|
||||||
inputs=inputs,
|
iterator_list_value: Sequence[object],
|
||||||
outputs={"output": outputs},
|
outputs: list[object],
|
||||||
steps=len(iterator_list_value),
|
iter_run_map: dict[str, float],
|
||||||
metadata={
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
# Initialize outputs list with None values to maintain order
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
outputs.extend([None] * len(iterator_list_value))
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield final success event
|
# Determine the number of parallel workers
|
||||||
yield StreamCompletedEvent(
|
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
outputs={"output": outputs},
|
# Submit all iteration tasks
|
||||||
metadata={
|
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
for index, item in enumerate(iterator_list_value):
|
||||||
},
|
yield IterationNextEvent(index=index)
|
||||||
|
future = executor.submit(
|
||||||
|
self._execute_single_iteration_parallel,
|
||||||
|
index=index,
|
||||||
|
item=item,
|
||||||
)
|
)
|
||||||
)
|
future_to_index[future] = index
|
||||||
except IterationNodeError as e:
|
|
||||||
yield IterationFailedEvent(
|
# Process completed iterations as they finish
|
||||||
start_at=started_at,
|
for future in as_completed(future_to_index):
|
||||||
inputs=inputs,
|
index = future_to_index[future]
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
iter_start_at, events, output_value, tokens_used = result
|
||||||
|
|
||||||
|
# Update outputs at the correct index
|
||||||
|
outputs[index] = output_value
|
||||||
|
|
||||||
|
# Yield all events from this iteration
|
||||||
|
yield from events
|
||||||
|
|
||||||
|
# Update tokens and timing
|
||||||
|
self.graph_runtime_state.total_tokens += tokens_used
|
||||||
|
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Handle errors based on error_handle_mode
|
||||||
|
match self._node_data.error_handle_mode:
|
||||||
|
case ErrorHandleMode.TERMINATED:
|
||||||
|
# Cancel remaining futures and re-raise
|
||||||
|
for f in future_to_index:
|
||||||
|
if f != future:
|
||||||
|
f.cancel()
|
||||||
|
raise IterationNodeError(str(e))
|
||||||
|
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||||
|
outputs[index] = None
|
||||||
|
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||||
|
outputs[index] = None # Will be filtered later
|
||||||
|
|
||||||
|
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||||
|
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||||
|
outputs[:] = [output for output in outputs if output is not None]
|
||||||
|
|
||||||
|
def _execute_single_iteration_parallel(
|
||||||
|
self,
|
||||||
|
index: int,
|
||||||
|
item: object,
|
||||||
|
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||||
|
"""Execute a single iteration in parallel mode and return results."""
|
||||||
|
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
events: list[GraphNodeEventBase] = []
|
||||||
|
outputs_temp: list[object] = []
|
||||||
|
|
||||||
|
graph_engine = self._create_graph_engine(index, item)
|
||||||
|
|
||||||
|
# Collect events instead of yielding them directly
|
||||||
|
for event in self._run_single_iter(
|
||||||
|
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||||
|
outputs=outputs_temp,
|
||||||
|
graph_engine=graph_engine,
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
# Get the output value from the temporary outputs list
|
||||||
|
output_value = outputs_temp[0] if outputs_temp else None
|
||||||
|
|
||||||
|
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||||
|
|
||||||
|
def _handle_iteration_success(
|
||||||
|
self,
|
||||||
|
started_at: datetime,
|
||||||
|
inputs: dict[str, Sequence[object]],
|
||||||
|
outputs: list[object],
|
||||||
|
iterator_list_value: Sequence[object],
|
||||||
|
iter_run_map: dict[str, float],
|
||||||
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
|
yield IterationSucceededEvent(
|
||||||
|
start_at=started_at,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs={"output": outputs},
|
||||||
|
steps=len(iterator_list_value),
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield final success event
|
||||||
|
yield StreamCompletedEvent(
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={"output": outputs},
|
outputs={"output": outputs},
|
||||||
steps=len(iterator_list_value),
|
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||||
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
|
||||||
},
|
},
|
||||||
error=str(e),
|
|
||||||
)
|
)
|
||||||
yield StreamCompletedEvent(
|
)
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
def _handle_iteration_failure(
|
||||||
error=str(e),
|
self,
|
||||||
)
|
started_at: datetime,
|
||||||
|
inputs: dict[str, Sequence[object]],
|
||||||
|
outputs: list[object],
|
||||||
|
iterator_list_value: Sequence[object],
|
||||||
|
iter_run_map: dict[str, float],
|
||||||
|
error: IterationNodeError,
|
||||||
|
) -> Generator[NodeEventBase, None, None]:
|
||||||
|
yield IterationFailedEvent(
|
||||||
|
start_at=started_at,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs={"output": outputs},
|
||||||
|
steps=len(iterator_list_value),
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
|
||||||
|
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
|
||||||
|
},
|
||||||
|
error=str(error),
|
||||||
|
)
|
||||||
|
yield StreamCompletedEvent(
|
||||||
|
node_run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=str(error),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
@ -305,9 +469,9 @@ class IterationNode(Node):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
outputs: list,
|
outputs: list[object],
|
||||||
graph_engine: "GraphEngine",
|
graph_engine: "GraphEngine",
|
||||||
) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
|
) -> Generator[GraphNodeEventBase, None, None]:
|
||||||
rst = graph_engine.run()
|
rst = graph_engine.run()
|
||||||
# get current iteration index
|
# get current iteration index
|
||||||
index_variable = variable_pool.get([self._node_id, "index"])
|
index_variable = variable_pool.get([self._node_id, "index"])
|
||||||
@ -338,7 +502,7 @@ class IterationNode(Node):
|
|||||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _create_graph_engine(self, index: int, item: Any):
|
def _create_graph_engine(self, index: int, item: object):
|
||||||
# Import dependencies
|
# Import dependencies
|
||||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
@ -387,18 +551,9 @@ class IterationNode(Node):
|
|||||||
|
|
||||||
# Create a new GraphEngine for this iteration
|
# Create a new GraphEngine for this iteration
|
||||||
graph_engine = GraphEngine(
|
graph_engine = GraphEngine(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
workflow_id=self.workflow_id,
|
workflow_id=self.workflow_id,
|
||||||
user_id=self.user_id,
|
|
||||||
user_from=self.user_from,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
call_depth=self.workflow_call_depth,
|
|
||||||
graph=iteration_graph,
|
graph=iteration_graph,
|
||||||
graph_config=self.graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
max_execution_steps=10000, # Use default or config value
|
|
||||||
max_execution_time=600, # Use default or config value
|
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import re
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from sqlalchemy import Float, and_, func, or_, select, text
|
from sqlalchemy import Float, and_, func, or_, select, text
|
||||||
from sqlalchemy import cast as sqlalchemy_cast
|
from sqlalchemy import cast as sqlalchemy_cast
|
||||||
@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node):
|
|||||||
return automatic_metadata_filters
|
return automatic_metadata_filters
|
||||||
|
|
||||||
def _process_metadata_filter_func(
|
def _process_metadata_filter_func(
|
||||||
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
|
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
if value is None and condition not in ("empty", "not empty"):
|
if value is None and condition not in ("empty", "not empty"):
|
||||||
return filters
|
return filters
|
||||||
|
|||||||
@ -959,7 +959,7 @@ class LLMNode(Node):
|
|||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
"type": "llm",
|
"type": "llm",
|
||||||
"config": {
|
"config": {
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Any, Literal, Optional
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData):
|
|||||||
loop_count: int # Maximum number of loops
|
loop_count: int # Maximum number of loops
|
||||||
break_conditions: list[Condition] # Conditions to break the loop
|
break_conditions: list[Condition] # Conditions to break the loop
|
||||||
logical_operator: Literal["and", "or"]
|
logical_operator: Literal["and", "or"]
|
||||||
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
|
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
@field_validator("outputs", mode="before")
|
@field_validator("outputs", mode="before")
|
||||||
@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
outputs: list[Any] = Field(default_factory=list)
|
outputs: list[Any] = Field(default_factory=list)
|
||||||
current_output: Any | None = None
|
current_output: Any = None
|
||||||
|
|
||||||
class MetaData(BaseLoopState.MetaData):
|
class MetaData(BaseLoopState.MetaData):
|
||||||
"""
|
"""
|
||||||
@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
|
|||||||
|
|
||||||
loop_length: int
|
loop_length: int
|
||||||
|
|
||||||
def get_last_output(self) -> Any | None:
|
def get_last_output(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Get last output.
|
Get last output.
|
||||||
"""
|
"""
|
||||||
@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
|
|||||||
return self.outputs[-1]
|
return self.outputs[-1]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_current_output(self) -> Any | None:
|
def get_current_output(self) -> Any:
|
||||||
"""
|
"""
|
||||||
Get current output.
|
Get current output.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.variables import Segment, SegmentType
|
from core.variables import Segment, SegmentType
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
ErrorStrategy,
|
||||||
@ -444,18 +443,9 @@ class LoopNode(Node):
|
|||||||
|
|
||||||
# Create a new GraphEngine for this iteration
|
# Create a new GraphEngine for this iteration
|
||||||
graph_engine = GraphEngine(
|
graph_engine = GraphEngine(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
workflow_id=self.workflow_id,
|
workflow_id=self.workflow_id,
|
||||||
user_id=self.user_id,
|
|
||||||
user_from=self.user_from,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
call_depth=self.workflow_call_depth,
|
|
||||||
graph=loop_graph,
|
graph=loop_graph,
|
||||||
graph_config=self.graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
|
||||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -118,7 +118,7 @@ class ParameterExtractorNode(Node):
|
|||||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
"model": {
|
"model": {
|
||||||
"prompt_templates": {
|
"prompt_templates": {
|
||||||
|
|||||||
@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
|
|||||||
return variable_mapping
|
return variable_mapping
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
Get default config of node.
|
Get default config of node.
|
||||||
:param filters: filter by node config parameters (not used in this implementation).
|
:param filters: filter by node config parameters (not used in this implementation).
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
|
|||||||
return self._node_data
|
return self._node_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: dict | None = None):
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
Get default config of node.
|
Get default config of node.
|
||||||
:param filters: filter by node config parameters.
|
:param filters: filter by node config parameters.
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from core.workflow.enums import (
|
|||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
@ -55,7 +55,7 @@ class ToolNode(Node):
|
|||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||||
"""
|
"""
|
||||||
Run the tool node
|
Run the tool node
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel):
|
|||||||
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
|
||||||
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
# 3. During the variable updating procedure: The `value` field is reassigned to hold
|
||||||
# the resolved actual value that will be applied to the target variable.
|
# the resolved actual value that will be applied to the target variable.
|
||||||
value: Any | None = None
|
value: Any = None
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerNodeData(BaseNodeData):
|
class VariableAssignerNodeData(BaseNodeData):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
@ -43,7 +43,7 @@ class WorkflowEntry:
|
|||||||
call_depth: int,
|
call_depth: int,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
command_channel: Optional[CommandChannel] = None,
|
command_channel: CommandChannel | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Init workflow entry
|
Init workflow entry
|
||||||
@ -73,18 +73,9 @@ class WorkflowEntry:
|
|||||||
|
|
||||||
self.command_channel = command_channel
|
self.command_channel = command_channel
|
||||||
self.graph_engine = GraphEngine(
|
self.graph_engine = GraphEngine(
|
||||||
tenant_id=tenant_id,
|
|
||||||
app_id=app_id,
|
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
user_id=user_id,
|
|
||||||
user_from=user_from,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
call_depth=call_depth,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
|
||||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,9 @@ class WorkflowRuntimeTypeConverter:
|
|||||||
|
|
||||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||||
result = self._to_json_encodable_recursive(value)
|
result = self._to_json_encodable_recursive(value)
|
||||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
if isinstance(result, Mapping) or result is None:
|
||||||
|
return result
|
||||||
|
return {}
|
||||||
|
|
||||||
def _to_json_encodable_recursive(self, value: Any):
|
def _to_json_encodable_recursive(self, value: Any):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
@ -846,7 +846,7 @@ class Conversation(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app(self) -> Optional[App]:
|
def app(self) -> App | None:
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
return session.query(App).where(App.id == self.app_id).first()
|
return session.query(App).where(App.id == self.app_id).first()
|
||||||
|
|
||||||
@ -1140,7 +1140,7 @@ class Message(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retriever_resources(self) -> Any | list[Any]:
|
def retriever_resources(self) -> Any:
|
||||||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@ -314,11 +315,11 @@ class MCPToolProvider(Base):
|
|||||||
return [MCPTool(**tool) for tool in json.loads(self.tools)]
|
return [MCPTool(**tool) for tool in json.loads(self.tools)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_icon(self) -> dict[str, str] | str:
|
def provider_icon(self) -> Mapping[str, str] | str:
|
||||||
from core.file import helpers as file_helpers
|
from core.file import helpers as file_helpers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return cast(dict[str, str], json.loads(self.icon))
|
return json.loads(self.icon)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return file_helpers.get_signed_file_url(self.icon)
|
return file_helpers.get_signed_file_url(self.icon)
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@
|
|||||||
"core/ops",
|
"core/ops",
|
||||||
"core/tools",
|
"core/tools",
|
||||||
"core/model_runtime",
|
"core/model_runtime",
|
||||||
"core/workflow",
|
"core/workflow/nodes",
|
||||||
"core/app/app_config/easy_ui_based_app/dataset"
|
"core/app/app_config/easy_ui_based_app/dataset"
|
||||||
],
|
],
|
||||||
"typeCheckingMode": "strict",
|
"typeCheckingMode": "strict",
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Union, cast
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
@ -39,7 +40,9 @@ class ToolTransformService:
|
|||||||
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
def get_tool_provider_icon_url(
|
||||||
|
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||||
|
) -> str | Mapping[str, str]:
|
||||||
"""
|
"""
|
||||||
get tool provider icon url
|
get tool provider icon url
|
||||||
"""
|
"""
|
||||||
@ -52,7 +55,7 @@ class ToolTransformService:
|
|||||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||||
try:
|
try:
|
||||||
if isinstance(icon, str):
|
if isinstance(icon, str):
|
||||||
return cast(dict, json.loads(icon))
|
return json.loads(icon)
|
||||||
return icon
|
return icon
|
||||||
except Exception:
|
except Exception:
|
||||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||||
@ -119,7 +122,7 @@ class ToolTransformService:
|
|||||||
name=provider_controller.entity.identity.name,
|
name=provider_controller.entity.identity.name,
|
||||||
description=provider_controller.entity.identity.description,
|
description=provider_controller.entity.identity.description,
|
||||||
icon=provider_controller.entity.identity.icon,
|
icon=provider_controller.entity.identity.icon,
|
||||||
icon_dark=provider_controller.entity.identity.icon_dark,
|
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||||
label=provider_controller.entity.identity.label,
|
label=provider_controller.entity.identity.label,
|
||||||
type=ToolProviderType.BUILT_IN,
|
type=ToolProviderType.BUILT_IN,
|
||||||
masked_credentials={},
|
masked_credentials={},
|
||||||
@ -141,9 +144,10 @@ class ToolTransformService:
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
masked_creds = {}
|
||||||
for name in schema:
|
for name in schema:
|
||||||
if result.masked_credentials:
|
masked_creds[name] = ""
|
||||||
result.masked_credentials[name] = ""
|
result.masked_credentials = masked_creds
|
||||||
|
|
||||||
# check if the provider need credentials
|
# check if the provider need credentials
|
||||||
if not provider_controller.need_credentials:
|
if not provider_controller.need_credentials:
|
||||||
@ -221,7 +225,7 @@ class ToolTransformService:
|
|||||||
name=provider_controller.entity.identity.name,
|
name=provider_controller.entity.identity.name,
|
||||||
description=provider_controller.entity.identity.description,
|
description=provider_controller.entity.identity.description,
|
||||||
icon=provider_controller.entity.identity.icon,
|
icon=provider_controller.entity.identity.icon,
|
||||||
icon_dark=provider_controller.entity.identity.icon_dark,
|
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||||
label=provider_controller.entity.identity.label,
|
label=provider_controller.entity.identity.label,
|
||||||
type=ToolProviderType.WORKFLOW,
|
type=ToolProviderType.WORKFLOW,
|
||||||
masked_credentials={},
|
masked_credentials={},
|
||||||
@ -334,7 +338,7 @@ class ToolTransformService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_tool_entity_to_api_entity(
|
def convert_tool_entity_to_api_entity(
|
||||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
tool: ApiToolBundle | WorkflowTool | Tool,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
labels: list[str] | None = None,
|
labels: list[str] | None = None,
|
||||||
) -> ToolApiEntity:
|
) -> ToolApiEntity:
|
||||||
@ -388,7 +392,7 @@ class ToolTransformService:
|
|||||||
parameters=merged_parameters,
|
parameters=merged_parameters,
|
||||||
labels=labels or [],
|
labels=labels or [],
|
||||||
)
|
)
|
||||||
elif isinstance(tool, ApiToolBundle):
|
else:
|
||||||
return ToolApiEntity(
|
return ToolApiEntity(
|
||||||
author=tool.author,
|
author=tool.author,
|
||||||
name=tool.operation_id or "",
|
name=tool.operation_id or "",
|
||||||
@ -397,9 +401,6 @@ class ToolTransformService:
|
|||||||
parameters=tool.parameters,
|
parameters=tool.parameters,
|
||||||
labels=labels or [],
|
labels=labels or [],
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Handle WorkflowTool case
|
|
||||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_builtin_provider_to_credential_entity(
|
def convert_builtin_provider_to_credential_entity(
|
||||||
|
|||||||
@ -563,12 +563,12 @@ class WorkflowService:
|
|||||||
# This will prevent validation errors from breaking the workflow
|
# This will prevent validation errors from breaking the workflow
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_default_block_configs(self) -> list[dict]:
|
def get_default_block_configs(self) -> Sequence[Mapping[str, object]]:
|
||||||
"""
|
"""
|
||||||
Get default block configs
|
Get default block configs
|
||||||
"""
|
"""
|
||||||
# return default block config
|
# return default block config
|
||||||
default_block_configs = []
|
default_block_configs: list[Mapping[str, object]] = []
|
||||||
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||||
node_class = node_class_mapping[LATEST_VERSION]
|
node_class = node_class_mapping[LATEST_VERSION]
|
||||||
default_config = node_class.get_default_config()
|
default_config = node_class.get_default_config()
|
||||||
@ -577,7 +577,9 @@ class WorkflowService:
|
|||||||
|
|
||||||
return default_block_configs
|
return default_block_configs
|
||||||
|
|
||||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
|
def get_default_block_config(
|
||||||
|
self, node_type: str, filters: Mapping[str, object] | None = None
|
||||||
|
) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
Get default config of node.
|
Get default config of node.
|
||||||
:param node_type: node type
|
:param node_type: node type
|
||||||
@ -588,12 +590,12 @@ class WorkflowService:
|
|||||||
|
|
||||||
# return default block config
|
# return default block config
|
||||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||||
return None
|
return {}
|
||||||
|
|
||||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||||
default_config = node_class.get_default_config(filters=filters)
|
default_config = node_class.get_default_config(filters=filters)
|
||||||
if not default_config:
|
if not default_config:
|
||||||
return None
|
return {}
|
||||||
|
|
||||||
return default_config
|
return default_config
|
||||||
|
|
||||||
@ -807,11 +809,13 @@ class WorkflowService:
|
|||||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||||
)
|
)
|
||||||
error = node_run_result.error if not run_succeeded else None
|
error = node_run_result.error if not run_succeeded else None
|
||||||
|
|
||||||
return node, node_run_result, run_succeeded, error
|
return node, node_run_result, run_succeeded, error
|
||||||
|
|
||||||
except WorkflowNodeRunFailedError as e:
|
except WorkflowNodeRunFailedError as e:
|
||||||
return e._node, None, False, e._error
|
node = e.node
|
||||||
|
run_succeeded = False
|
||||||
|
node_run_result = None
|
||||||
|
error = e.error
|
||||||
|
return node, node_run_result, run_succeeded, error
|
||||||
|
|
||||||
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
|
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
|
||||||
"""Apply error strategy when node execution fails."""
|
"""Apply error strategy when node execution fails."""
|
||||||
|
|||||||
@ -89,6 +89,7 @@ def test_execute_code(setup_code_executor_mock):
|
|||||||
code_config = {
|
code_config = {
|
||||||
"id": "code",
|
"id": "code",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "code",
|
||||||
"outputs": {
|
"outputs": {
|
||||||
"result": {
|
"result": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
@ -135,6 +136,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
|
|||||||
code_config = {
|
code_config = {
|
||||||
"id": "code",
|
"id": "code",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "code",
|
||||||
"outputs": {
|
"outputs": {
|
||||||
"result": {
|
"result": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -180,6 +182,7 @@ def test_execute_code_output_validator_depth():
|
|||||||
code_config = {
|
code_config = {
|
||||||
"id": "code",
|
"id": "code",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "code",
|
||||||
"outputs": {
|
"outputs": {
|
||||||
"string_validator": {
|
"string_validator": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@ -298,6 +301,7 @@ def test_execute_code_output_object_list():
|
|||||||
code_config = {
|
code_config = {
|
||||||
"id": "code",
|
"id": "code",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "code",
|
||||||
"outputs": {
|
"outputs": {
|
||||||
"object_list": {
|
"object_list": {
|
||||||
"type": "array[object]",
|
"type": "array[object]",
|
||||||
@ -358,7 +362,8 @@ def test_execute_code_output_object_list():
|
|||||||
node._transform_result(result, node._node_data.outputs)
|
node._transform_result(result, node._node_data.outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_execute_code_scientific_notation():
|
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||||
|
def test_execute_code_scientific_notation(setup_code_executor_mock):
|
||||||
code = """
|
code = """
|
||||||
def main():
|
def main():
|
||||||
return {
|
return {
|
||||||
@ -370,6 +375,7 @@ def test_execute_code_scientific_notation():
|
|||||||
code_config = {
|
code_config = {
|
||||||
"id": "code",
|
"id": "code",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "code",
|
||||||
"outputs": {
|
"outputs": {
|
||||||
"result": {
|
"result": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
|
|||||||
@ -77,6 +77,7 @@ def test_get(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -110,6 +111,7 @@ def test_no_auth(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -139,6 +141,7 @@ def test_custom_authorization_header(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -231,6 +234,7 @@ def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -271,6 +275,7 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -310,6 +315,7 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -343,6 +349,7 @@ def test_template(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -378,6 +385,7 @@ def test_json(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "post",
|
"method": "post",
|
||||||
@ -420,6 +428,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "post",
|
"method": "post",
|
||||||
@ -467,6 +476,7 @@ def test_form_data(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "post",
|
"method": "post",
|
||||||
@ -517,6 +527,7 @@ def test_none_data(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "post",
|
"method": "post",
|
||||||
@ -550,6 +561,7 @@ def test_mock_404(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -579,6 +591,7 @@ def test_multi_colons_parse(setup_http_mock):
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
@ -635,6 +648,7 @@ def test_nested_object_variable_selector(setup_http_mock):
|
|||||||
{
|
{
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
"title": "http",
|
"title": "http",
|
||||||
"desc": "",
|
"desc": "",
|
||||||
"method": "get",
|
"method": "get",
|
||||||
|
|||||||
@ -20,6 +20,7 @@ def test_execute_code(setup_code_executor_mock):
|
|||||||
config = {
|
config = {
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "template-transform",
|
||||||
"title": "123",
|
"title": "123",
|
||||||
"variables": [
|
"variables": [
|
||||||
{
|
{
|
||||||
|
|||||||
@ -70,6 +70,7 @@ def test_tool_variable_invoke():
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "tool",
|
||||||
"title": "a",
|
"title": "a",
|
||||||
"desc": "a",
|
"desc": "a",
|
||||||
"provider_id": "time",
|
"provider_id": "time",
|
||||||
@ -101,6 +102,7 @@ def test_tool_mixed_invoke():
|
|||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": {
|
"data": {
|
||||||
|
"type": "tool",
|
||||||
"title": "a",
|
"title": "a",
|
||||||
"desc": "a",
|
"desc": "a",
|
||||||
"provider_id": "time",
|
"provider_id": "time",
|
||||||
|
|||||||
@ -454,7 +454,7 @@ class TestToolTransformService:
|
|||||||
name=fake.company(),
|
name=fake.company(),
|
||||||
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
|
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
|
||||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||||
icon_dark=None,
|
icon_dark="",
|
||||||
label=I18nObject(en_US=fake.company()),
|
label=I18nObject(en_US=fake.company()),
|
||||||
type=ToolProviderType.API,
|
type=ToolProviderType.API,
|
||||||
masked_credentials={},
|
masked_credentials={},
|
||||||
@ -473,8 +473,8 @@ class TestToolTransformService:
|
|||||||
assert provider.icon["background"] == "#FF6B6B"
|
assert provider.icon["background"] == "#FF6B6B"
|
||||||
assert provider.icon["content"] == "🔧"
|
assert provider.icon["content"] == "🔧"
|
||||||
|
|
||||||
# Verify dark icon remains None
|
# Verify dark icon remains empty string
|
||||||
assert provider.icon_dark is None
|
assert provider.icon_dark == ""
|
||||||
|
|
||||||
def test_builtin_provider_to_user_provider_success(
|
def test_builtin_provider_to_user_provider_success(
|
||||||
self, db_session_with_containers, mock_external_service_dependencies
|
self, db_session_with_containers, mock_external_service_dependencies
|
||||||
@ -628,7 +628,7 @@ class TestToolTransformService:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.is_team_authorization is True
|
assert result.is_team_authorization is True
|
||||||
assert result.allow_delete is False
|
assert result.allow_delete is False
|
||||||
assert result.masked_credentials == {}
|
assert result.masked_credentials == {"api_key": ""}
|
||||||
|
|
||||||
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
|
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -17,7 +17,6 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
|||||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||||
parameters=[],
|
parameters=[],
|
||||||
description=None,
|
description=None,
|
||||||
output_schema=None,
|
|
||||||
has_runtime_parameters=False,
|
has_runtime_parameters=False,
|
||||||
)
|
)
|
||||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||||
|
|||||||
@ -95,17 +95,3 @@ class TestGraphRuntimeState:
|
|||||||
# Test add_tokens validation
|
# Test add_tokens validation
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
state.add_tokens(-1)
|
state.add_tokens(-1)
|
||||||
|
|
||||||
def test_deep_copy_for_nested_objects(self):
|
|
||||||
variable_pool = VariablePool()
|
|
||||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
|
||||||
|
|
||||||
# Test deep copy for nested dict
|
|
||||||
nested_data = {"level1": {"level2": {"value": "test"}}}
|
|
||||||
state.set_output("nested", nested_data)
|
|
||||||
|
|
||||||
retrieved = state.get_output("nested")
|
|
||||||
retrieved["level1"]["level2"]["value"] = "modified"
|
|
||||||
|
|
||||||
# Original should remain unchanged
|
|
||||||
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
|
|
||||||
|
|||||||
@ -3,14 +3,12 @@
|
|||||||
import time
|
import time
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
|
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
|
|
||||||
def test_abort_command():
|
def test_abort_command():
|
||||||
@ -42,18 +40,9 @@ def test_abort_command():
|
|||||||
|
|
||||||
# Create GraphEngine with same shared runtime state
|
# Create GraphEngine with same shared runtime state
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test",
|
|
||||||
app_id="test",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
call_depth=0,
|
|
||||||
graph=mock_graph,
|
graph=mock_graph,
|
||||||
graph_config={},
|
|
||||||
graph_runtime_state=shared_runtime_state, # Use shared instance
|
graph_runtime_state=shared_runtime_state, # Use shared instance
|
||||||
max_execution_steps=100,
|
|
||||||
max_execution_time=10,
|
|
||||||
command_channel=command_channel,
|
command_channel=command_channel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,6 @@ This test validates that:
|
|||||||
- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output)
|
- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
@ -16,7 +15,6 @@ from core.workflow.graph_events import (
|
|||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
from .test_table_runner import TableTestRunner
|
from .test_table_runner import TableTestRunner
|
||||||
|
|
||||||
@ -40,23 +38,11 @@ def test_streaming_output_with_blocking_equals_one():
|
|||||||
use_mock_factory=True,
|
use_mock_factory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_config = fixture_data.get("workflow", {})
|
|
||||||
graph_config = workflow_config.get("graph", {})
|
|
||||||
|
|
||||||
# Create and run the engine
|
# Create and run the engine
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=500,
|
|
||||||
max_execution_time=30,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,23 +133,11 @@ def test_streaming_output_with_blocking_not_equals_one():
|
|||||||
use_mock_factory=True,
|
use_mock_factory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_config = fixture_data.get("workflow", {})
|
|
||||||
graph_config = workflow_config.get("graph", {})
|
|
||||||
|
|
||||||
# Create and run the engine
|
# Create and run the engine
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=500,
|
|
||||||
max_execution_time=30,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import contextvars
|
|||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from flask import Flask, g
|
from flask import Flask, g
|
||||||
@ -59,7 +58,7 @@ class TestContextPreservation:
|
|||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
# Variable to store value from worker
|
# Variable to store value from worker
|
||||||
worker_value: Optional[str] = None
|
worker_value: str | None = None
|
||||||
|
|
||||||
def worker_task() -> None:
|
def worker_task() -> None:
|
||||||
nonlocal worker_value
|
nonlocal worker_value
|
||||||
@ -120,7 +119,7 @@ class TestContextPreservation:
|
|||||||
test_node = MagicMock(spec=Node)
|
test_node = MagicMock(spec=Node)
|
||||||
|
|
||||||
# Variable to capture context inside node execution
|
# Variable to capture context inside node execution
|
||||||
captured_value: Optional[str] = None
|
captured_value: str | None = None
|
||||||
context_available_in_node = False
|
context_available_in_node = False
|
||||||
|
|
||||||
def mock_run() -> list[GraphNodeEventBase]:
|
def mock_run() -> list[GraphNodeEventBase]:
|
||||||
|
|||||||
@ -10,11 +10,9 @@ import time
|
|||||||
from hypothesis import HealthCheck, given, settings
|
from hypothesis import HealthCheck, given, settings
|
||||||
from hypothesis import strategies as st
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
from core.workflow.graph_events import GraphRunStartedEvent, GraphRunSucceededEvent
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
# Import the test framework from the new module
|
# Import the test framework from the new module
|
||||||
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase
|
||||||
@ -460,18 +458,9 @@ def test_layer_system_basic():
|
|||||||
|
|
||||||
# Create engine with layer
|
# Create engine with layer
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=300,
|
|
||||||
max_execution_time=60,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -525,18 +514,9 @@ def test_layer_chaining():
|
|||||||
|
|
||||||
# Create engine
|
# Create engine
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=300,
|
|
||||||
max_execution_time=60,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -581,18 +561,9 @@ def test_layer_error_handling():
|
|||||||
|
|
||||||
# Create engine with faulty layer
|
# Create engine with faulty layer
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=fixture_data.get("workflow", {}).get("graph", {}),
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=300,
|
|
||||||
max_execution_time=60,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,194 @@
|
|||||||
|
"""Unit tests for GraphExecution serialization helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections import deque
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||||
|
from core.workflow.graph_engine.domain import GraphExecution
|
||||||
|
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||||
|
from core.workflow.graph_engine.response_coordinator.path import Path
|
||||||
|
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||||
|
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment
|
||||||
|
|
||||||
|
|
||||||
|
class CustomGraphExecutionError(Exception):
|
||||||
|
"""Custom exception used to verify error serialization."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_execution_serialization_round_trip() -> None:
|
||||||
|
"""GraphExecution serialization restores full aggregate state."""
|
||||||
|
# Arrange
|
||||||
|
execution = GraphExecution(workflow_id="wf-1")
|
||||||
|
execution.start()
|
||||||
|
node_a = execution.get_or_create_node_execution("node-a")
|
||||||
|
node_a.mark_started(execution_id="exec-1")
|
||||||
|
node_a.increment_retry()
|
||||||
|
node_a.mark_failed("boom")
|
||||||
|
node_b = execution.get_or_create_node_execution("node-b")
|
||||||
|
node_b.mark_skipped()
|
||||||
|
execution.fail(CustomGraphExecutionError("serialization failure"))
|
||||||
|
|
||||||
|
# Act
|
||||||
|
serialized = execution.dumps()
|
||||||
|
payload = json.loads(serialized)
|
||||||
|
restored = GraphExecution(workflow_id="wf-1")
|
||||||
|
restored.loads(serialized)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert payload["type"] == "GraphExecution"
|
||||||
|
assert payload["version"] == "1.0"
|
||||||
|
assert restored.workflow_id == "wf-1"
|
||||||
|
assert restored.started is True
|
||||||
|
assert restored.completed is True
|
||||||
|
assert restored.aborted is False
|
||||||
|
assert isinstance(restored.error, CustomGraphExecutionError)
|
||||||
|
assert str(restored.error) == "serialization failure"
|
||||||
|
assert set(restored.node_executions) == {"node-a", "node-b"}
|
||||||
|
restored_node_a = restored.node_executions["node-a"]
|
||||||
|
assert restored_node_a.state is NodeState.TAKEN
|
||||||
|
assert restored_node_a.retry_count == 1
|
||||||
|
assert restored_node_a.execution_id == "exec-1"
|
||||||
|
assert restored_node_a.error == "boom"
|
||||||
|
restored_node_b = restored.node_executions["node-b"]
|
||||||
|
assert restored_node_b.state is NodeState.SKIPPED
|
||||||
|
assert restored_node_b.retry_count == 0
|
||||||
|
assert restored_node_b.execution_id is None
|
||||||
|
assert restored_node_b.error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_execution_loads_replaces_existing_state() -> None:
|
||||||
|
"""loads replaces existing runtime data with serialized snapshot."""
|
||||||
|
# Arrange
|
||||||
|
source = GraphExecution(workflow_id="wf-2")
|
||||||
|
source.start()
|
||||||
|
source_node = source.get_or_create_node_execution("node-source")
|
||||||
|
source_node.mark_taken()
|
||||||
|
serialized = source.dumps()
|
||||||
|
|
||||||
|
target = GraphExecution(workflow_id="wf-2")
|
||||||
|
target.start()
|
||||||
|
target.abort("pre-existing abort")
|
||||||
|
temp_node = target.get_or_create_node_execution("node-temp")
|
||||||
|
temp_node.increment_retry()
|
||||||
|
temp_node.mark_failed("temp error")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
target.loads(serialized)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert target.aborted is False
|
||||||
|
assert target.error is None
|
||||||
|
assert target.started is True
|
||||||
|
assert target.completed is False
|
||||||
|
assert set(target.node_executions) == {"node-source"}
|
||||||
|
restored_node = target.node_executions["node-source"]
|
||||||
|
assert restored_node.state is NodeState.TAKEN
|
||||||
|
assert restored_node.retry_count == 0
|
||||||
|
assert restored_node.execution_id is None
|
||||||
|
assert restored_node.error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None:
|
||||||
|
"""ResponseStreamCoordinator serialization restores coordinator internals."""
|
||||||
|
|
||||||
|
template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])])
|
||||||
|
template_secondary = Template(segments=[TextSegment(text="secondary")])
|
||||||
|
|
||||||
|
class DummyNode:
|
||||||
|
def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None:
|
||||||
|
self.id = node_id
|
||||||
|
self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM
|
||||||
|
self.execution_type = execution_type
|
||||||
|
self.state = NodeState.UNKNOWN
|
||||||
|
self.title = node_id
|
||||||
|
self.template = template
|
||||||
|
|
||||||
|
def blocks_variable_output(self, *_args) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
response_node1 = DummyNode("response-1", template_main, NodeExecutionType.RESPONSE)
|
||||||
|
response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE)
|
||||||
|
response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE)
|
||||||
|
source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE)
|
||||||
|
|
||||||
|
class DummyGraph:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.nodes = {
|
||||||
|
response_node1.id: response_node1,
|
||||||
|
response_node2.id: response_node2,
|
||||||
|
response_node3.id: response_node3,
|
||||||
|
source_node.id: source_node,
|
||||||
|
}
|
||||||
|
self.edges: dict[str, object] = {}
|
||||||
|
self.root_node = response_node1
|
||||||
|
|
||||||
|
def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised
|
||||||
|
return []
|
||||||
|
|
||||||
|
graph = DummyGraph()
|
||||||
|
|
||||||
|
def fake_from_node(cls, node: DummyNode) -> ResponseSession:
|
||||||
|
return ResponseSession(node_id=node.id, template=node.template)
|
||||||
|
|
||||||
|
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||||
|
|
||||||
|
coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||||
|
coordinator._response_nodes = {"response-1", "response-2", "response-3"}
|
||||||
|
coordinator._paths_maps = {
|
||||||
|
"response-1": [Path(edges=["edge-1"])],
|
||||||
|
"response-2": [Path(edges=[])],
|
||||||
|
"response-3": [Path(edges=["edge-2", "edge-3"])],
|
||||||
|
}
|
||||||
|
|
||||||
|
active_session = ResponseSession(node_id="response-1", template=response_node1.template)
|
||||||
|
active_session.index = 1
|
||||||
|
coordinator._active_session = active_session
|
||||||
|
waiting_session = ResponseSession(node_id="response-2", template=response_node2.template)
|
||||||
|
coordinator._waiting_sessions = deque([waiting_session])
|
||||||
|
pending_session = ResponseSession(node_id="response-3", template=response_node3.template)
|
||||||
|
pending_session.index = 2
|
||||||
|
coordinator._response_sessions = {"response-3": pending_session}
|
||||||
|
|
||||||
|
coordinator._node_execution_ids = {"response-1": "exec-1"}
|
||||||
|
event = NodeRunStreamChunkEvent(
|
||||||
|
id="exec-1",
|
||||||
|
node_id="response-1",
|
||||||
|
node_type=NodeType.ANSWER,
|
||||||
|
selector=["node-source", "text"],
|
||||||
|
chunk="chunk-1",
|
||||||
|
is_final=False,
|
||||||
|
)
|
||||||
|
coordinator._stream_buffers = {("node-source", "text"): [event]}
|
||||||
|
coordinator._stream_positions = {("node-source", "text"): 1}
|
||||||
|
coordinator._closed_streams = {("node-source", "text")}
|
||||||
|
|
||||||
|
serialized = coordinator.dumps()
|
||||||
|
|
||||||
|
restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type]
|
||||||
|
monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node))
|
||||||
|
restored.loads(serialized)
|
||||||
|
|
||||||
|
assert restored._response_nodes == {"response-1", "response-2", "response-3"}
|
||||||
|
assert restored._paths_maps["response-1"][0].edges == ["edge-1"]
|
||||||
|
assert restored._active_session is not None
|
||||||
|
assert restored._active_session.node_id == "response-1"
|
||||||
|
assert restored._active_session.index == 1
|
||||||
|
waiting_restored = list(restored._waiting_sessions)
|
||||||
|
assert len(waiting_restored) == 1
|
||||||
|
assert waiting_restored[0].node_id == "response-2"
|
||||||
|
assert waiting_restored[0].index == 0
|
||||||
|
assert set(restored._response_sessions) == {"response-3"}
|
||||||
|
assert restored._response_sessions["response-3"].index == 2
|
||||||
|
assert restored._node_execution_ids == {"response-1": "exec-1"}
|
||||||
|
assert ("node-source", "text") in restored._stream_buffers
|
||||||
|
restored_event = restored._stream_buffers[("node-source", "text")][0]
|
||||||
|
assert restored_event.chunk == "chunk-1"
|
||||||
|
assert restored._stream_positions[("node-source", "text")] == 1
|
||||||
|
assert ("node-source", "text") in restored._closed_streams
|
||||||
@ -7,7 +7,7 @@ the behavior of mock nodes during testing.
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
|
|
||||||
@ -18,9 +18,9 @@ class NodeMockConfig:
|
|||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
outputs: dict[str, Any] = field(default_factory=dict)
|
outputs: dict[str, Any] = field(default_factory=dict)
|
||||||
error: Optional[str] = None
|
error: str | None = None
|
||||||
delay: float = 0.0 # Simulated execution delay in seconds
|
delay: float = 0.0 # Simulated execution delay in seconds
|
||||||
custom_handler: Optional[Callable[..., dict[str, Any]]] = None
|
custom_handler: Callable[..., dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -51,7 +51,7 @@ class MockConfig:
|
|||||||
default_template_transform_response: str = "This is mocked template transform output"
|
default_template_transform_response: str = "This is mocked template transform output"
|
||||||
default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
|
default_code_response: dict[str, Any] = field(default_factory=lambda: {"result": "mocked code execution result"})
|
||||||
|
|
||||||
def get_node_config(self, node_id: str) -> Optional[NodeMockConfig]:
|
def get_node_config(self, node_id: str) -> NodeMockConfig | None:
|
||||||
"""Get configuration for a specific node."""
|
"""Get configuration for a specific node."""
|
||||||
return self.node_configs.get(node_id)
|
return self.node_configs.get(node_id)
|
||||||
|
|
||||||
|
|||||||
@ -64,7 +64,7 @@ class MockNodeMixin:
|
|||||||
|
|
||||||
return default_outputs
|
return default_outputs
|
||||||
|
|
||||||
def _should_simulate_error(self) -> Optional[str]:
|
def _should_simulate_error(self) -> str | None:
|
||||||
"""Check if this node should simulate an error."""
|
"""Check if this node should simulate an error."""
|
||||||
if not self.mock_config:
|
if not self.mock_config:
|
||||||
return None
|
return None
|
||||||
@ -615,18 +615,9 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
|||||||
|
|
||||||
# Create a new GraphEngine for this iteration
|
# Create a new GraphEngine for this iteration
|
||||||
graph_engine = GraphEngine(
|
graph_engine = GraphEngine(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
workflow_id=self.workflow_id,
|
workflow_id=self.workflow_id,
|
||||||
user_id=self.user_id,
|
|
||||||
user_from=self.user_from,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
call_depth=self.workflow_call_depth,
|
|
||||||
graph=iteration_graph,
|
graph=iteration_graph,
|
||||||
graph_config=self.graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
max_execution_steps=10000, # Use default or config value
|
|
||||||
max_execution_time=600, # Use default or config value
|
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -685,18 +676,9 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
|||||||
|
|
||||||
# Create a new GraphEngine for this iteration
|
# Create a new GraphEngine for this iteration
|
||||||
graph_engine = GraphEngine(
|
graph_engine = GraphEngine(
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_id,
|
|
||||||
workflow_id=self.workflow_id,
|
workflow_id=self.workflow_id,
|
||||||
user_id=self.user_id,
|
|
||||||
user_from=self.user_from,
|
|
||||||
invoke_from=self.invoke_from,
|
|
||||||
call_depth=self.workflow_call_depth,
|
|
||||||
graph=loop_graph,
|
graph=loop_graph,
|
||||||
graph_config=self.graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state_copy,
|
graph_runtime_state=graph_runtime_state_copy,
|
||||||
max_execution_steps=10000, # Use default or config value
|
|
||||||
max_execution_time=600, # Use default or config value
|
|
||||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -118,18 +118,9 @@ def test_parallel_streaming_workflow():
|
|||||||
|
|
||||||
# Create the graph engine
|
# Create the graph engine
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.WEB_APP,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=500,
|
|
||||||
max_execution_time=30,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -17,10 +17,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
@ -42,7 +41,6 @@ from core.workflow.graph_events import (
|
|||||||
)
|
)
|
||||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
from .test_mock_config import MockConfig
|
from .test_mock_config import MockConfig
|
||||||
from .test_mock_factory import MockNodeFactory
|
from .test_mock_factory import MockNodeFactory
|
||||||
@ -60,14 +58,14 @@ class WorkflowTestCase:
|
|||||||
query: str = ""
|
query: str = ""
|
||||||
description: str = ""
|
description: str = ""
|
||||||
timeout: float = 30.0
|
timeout: float = 30.0
|
||||||
mock_config: Optional[MockConfig] = None
|
mock_config: MockConfig | None = None
|
||||||
use_auto_mock: bool = False
|
use_auto_mock: bool = False
|
||||||
expected_event_sequence: Optional[Sequence[type[GraphEngineEvent]]] = None
|
expected_event_sequence: Sequence[type[GraphEngineEvent]] | None = None
|
||||||
tags: list[str] = field(default_factory=list)
|
tags: list[str] = field(default_factory=list)
|
||||||
skip: bool = False
|
skip: bool = False
|
||||||
skip_reason: str = ""
|
skip_reason: str = ""
|
||||||
retry_count: int = 0
|
retry_count: int = 0
|
||||||
custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None
|
custom_validator: Callable[[dict[str, Any]], bool] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -76,14 +74,14 @@ class WorkflowTestResult:
|
|||||||
|
|
||||||
test_case: WorkflowTestCase
|
test_case: WorkflowTestCase
|
||||||
success: bool
|
success: bool
|
||||||
error: Optional[Exception] = None
|
error: Exception | None = None
|
||||||
actual_outputs: Optional[dict[str, Any]] = None
|
actual_outputs: dict[str, Any] | None = None
|
||||||
execution_time: float = 0.0
|
execution_time: float = 0.0
|
||||||
event_sequence_match: Optional[bool] = None
|
event_sequence_match: bool | None = None
|
||||||
event_mismatch_details: Optional[str] = None
|
event_mismatch_details: str | None = None
|
||||||
events: list[GraphEngineEvent] = field(default_factory=list)
|
events: list[GraphEngineEvent] = field(default_factory=list)
|
||||||
retry_attempts: int = 0
|
retry_attempts: int = 0
|
||||||
validation_details: Optional[str] = None
|
validation_details: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -116,7 +114,7 @@ class TestSuiteResult:
|
|||||||
class WorkflowRunner:
|
class WorkflowRunner:
|
||||||
"""Core workflow execution engine for tests."""
|
"""Core workflow execution engine for tests."""
|
||||||
|
|
||||||
def __init__(self, fixtures_dir: Optional[Path] = None):
|
def __init__(self, fixtures_dir: Path | None = None):
|
||||||
"""Initialize the workflow runner."""
|
"""Initialize the workflow runner."""
|
||||||
if fixtures_dir is None:
|
if fixtures_dir is None:
|
||||||
# Use the new central fixtures location
|
# Use the new central fixtures location
|
||||||
@ -147,9 +145,9 @@ class WorkflowRunner:
|
|||||||
self,
|
self,
|
||||||
fixture_data: dict[str, Any],
|
fixture_data: dict[str, Any],
|
||||||
query: str = "",
|
query: str = "",
|
||||||
inputs: Optional[dict[str, Any]] = None,
|
inputs: dict[str, Any] | None = None,
|
||||||
use_mock_factory: bool = False,
|
use_mock_factory: bool = False,
|
||||||
mock_config: Optional[MockConfig] = None,
|
mock_config: MockConfig | None = None,
|
||||||
) -> tuple[Graph, GraphRuntimeState]:
|
) -> tuple[Graph, GraphRuntimeState]:
|
||||||
"""Create a Graph instance from fixture data."""
|
"""Create a Graph instance from fixture data."""
|
||||||
workflow_config = fixture_data.get("workflow", {})
|
workflow_config = fixture_data.get("workflow", {})
|
||||||
@ -240,7 +238,7 @@ class TableTestRunner:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
fixtures_dir: Optional[Path] = None,
|
fixtures_dir: Path | None = None,
|
||||||
max_workers: int = 4,
|
max_workers: int = 4,
|
||||||
enable_logging: bool = False,
|
enable_logging: bool = False,
|
||||||
log_level: str = "INFO",
|
log_level: str = "INFO",
|
||||||
@ -373,23 +371,11 @@ class TableTestRunner:
|
|||||||
mock_config=test_case.mock_config,
|
mock_config=test_case.mock_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_config = fixture_data.get("workflow", {})
|
|
||||||
graph_config = workflow_config.get("graph", {})
|
|
||||||
|
|
||||||
# Create and run the engine with configured worker settings
|
# Create and run the engine with configured worker settings
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER, # Use DEBUGGER to avoid conversation_id requirement
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=500,
|
|
||||||
max_execution_time=int(test_case.timeout),
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
min_workers=self.graph_engine_min_workers,
|
min_workers=self.graph_engine_min_workers,
|
||||||
max_workers=self.graph_engine_max_workers,
|
max_workers=self.graph_engine_max_workers,
|
||||||
@ -469,8 +455,8 @@ class TableTestRunner:
|
|||||||
self,
|
self,
|
||||||
expected_outputs: dict[str, Any],
|
expected_outputs: dict[str, Any],
|
||||||
actual_outputs: dict[str, Any],
|
actual_outputs: dict[str, Any],
|
||||||
custom_validator: Optional[Callable[[dict[str, Any]], bool]] = None,
|
custom_validator: Callable[[dict[str, Any]], bool] | None = None,
|
||||||
) -> tuple[bool, Optional[str]]:
|
) -> tuple[bool, str | None]:
|
||||||
"""
|
"""
|
||||||
Validate actual outputs against expected outputs.
|
Validate actual outputs against expected outputs.
|
||||||
|
|
||||||
@ -519,7 +505,7 @@ class TableTestRunner:
|
|||||||
|
|
||||||
def _validate_event_sequence(
|
def _validate_event_sequence(
|
||||||
self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
|
self, expected_sequence: list[type[GraphEngineEvent]], actual_events: list[GraphEngineEvent]
|
||||||
) -> tuple[bool, Optional[str]]:
|
) -> tuple[bool, str | None]:
|
||||||
"""
|
"""
|
||||||
Validate that actual events match the expected event sequence.
|
Validate that actual events match the expected event sequence.
|
||||||
|
|
||||||
@ -551,7 +537,7 @@ class TableTestRunner:
|
|||||||
self,
|
self,
|
||||||
test_cases: list[WorkflowTestCase],
|
test_cases: list[WorkflowTestCase],
|
||||||
parallel: bool = False,
|
parallel: bool = False,
|
||||||
tags_filter: Optional[list[str]] = None,
|
tags_filter: list[str] | None = None,
|
||||||
fail_fast: bool = False,
|
fail_fast: bool = False,
|
||||||
) -> TestSuiteResult:
|
) -> TestSuiteResult:
|
||||||
"""
|
"""
|
||||||
@ -715,4 +701,4 @@ def _load_fixture(fixture_path: Path, fixture_name: str) -> dict[str, Any]:
|
|||||||
if not fixture_path.exists():
|
if not fixture_path.exists():
|
||||||
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
||||||
|
|
||||||
return load_yaml_file(str(fixture_path), ignore_error=False)
|
return _load_yaml_file(file_path=str(fixture_path))
|
||||||
|
|||||||
@ -1,11 +1,9 @@
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.graph_engine import GraphEngine
|
from core.workflow.graph_engine import GraphEngine
|
||||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
)
|
)
|
||||||
from models.enums import UserFrom
|
|
||||||
|
|
||||||
from .test_table_runner import TableTestRunner
|
from .test_table_runner import TableTestRunner
|
||||||
|
|
||||||
@ -23,23 +21,11 @@ def test_tool_in_chatflow():
|
|||||||
use_mock_factory=True,
|
use_mock_factory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_config = fixture_data.get("workflow", {})
|
|
||||||
graph_config = workflow_config.get("graph", {})
|
|
||||||
|
|
||||||
# Create and run the engine
|
# Create and run the engine
|
||||||
engine = GraphEngine(
|
engine = GraphEngine(
|
||||||
tenant_id="test_tenant",
|
|
||||||
app_id="test_app",
|
|
||||||
workflow_id="test_workflow",
|
workflow_id="test_workflow",
|
||||||
user_id="test_user",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
|
||||||
call_depth=0,
|
|
||||||
graph=graph,
|
graph=graph,
|
||||||
graph_config=graph_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
max_execution_steps=500,
|
|
||||||
max_execution_time=30,
|
|
||||||
command_channel=InMemoryChannel(),
|
command_channel=InMemoryChannel(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user