fix: type errors

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-09-10 21:48:05 +08:00
parent 00a1af8506
commit b4c1766932
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
55 changed files with 305 additions and 289 deletions

View File

@ -1,9 +1,33 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TypedDict
from pydantic import BaseModel
class CodeNodeProvider(BaseModel):
class VariableConfig(TypedDict):
variable: str
value_selector: Sequence[str | int]
class OutputConfig(TypedDict):
type: str
children: None
class CodeConfig(TypedDict):
variables: Sequence[VariableConfig]
code_language: str
code: str
outputs: Mapping[str, OutputConfig]
class DefaultConfig(TypedDict):
type: str
config: CodeConfig
class CodeNodeProvider(BaseModel, ABC):
@staticmethod
@abstractmethod
def get_language() -> str:
@ -22,11 +46,14 @@ class CodeNodeProvider(BaseModel):
pass
@classmethod
def get_default_config(cls):
def get_default_config(cls) -> DefaultConfig:
return {
"type": "code",
"config": {
"variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}],
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {"result": {"type": "string", "children": None}},

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
@ -7,4 +5,4 @@ class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: Optional[str] = None
icon: str | None = None

View File

@ -1,5 +1,5 @@
import hashlib
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@ -10,10 +10,10 @@ class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
branch_identify: str | None = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
conditions: list[Condition] | None = None
"""conditions to run the node, required when type is condition"""
@property

View File

@ -7,7 +7,7 @@ implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -28,7 +28,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
@ -37,7 +37,7 @@ class WorkflowExecution(BaseModel):
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: Optional[datetime] = None
finished_at: datetime | None = None
@property
def elapsed_time(self) -> float:

View File

@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -39,41 +39,41 @@ class WorkflowNodeExecution(BaseModel):
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
# In most scenarios, `id` should be used as the primary identifier.
node_execution_id: Optional[str] = None
node_execution_id: str | None = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging)
# --------- Core identification fields ends ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: Optional[str] = None # Error message if execution failed
error: str | None = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: Optional[datetime] = None # When execution completed
finished_at: datetime | None = None # When execution completed
def update_from_mapping(
self,
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
inputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
):
"""
Update the model from mappings.

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -14,4 +14,4 @@ class NodeRunAgentLogEvent(GraphAgentNodeEventBase):
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -10,7 +10,7 @@ class GraphRunStartedEvent(BaseGraphEvent):
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None
class GraphRunFailedEvent(BaseGraphEvent):
@ -20,11 +20,11 @@ class GraphRunFailedEvent(BaseGraphEvent):
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
outputs: dict[str, Any] | None = None
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: Optional[str] = Field(default=None, description="reason for abort")
outputs: Optional[dict[str, Any]] = Field(default=None, description="partial outputs if any")
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, Any] | None = Field(default=None, description="partial outputs if any")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
class NodeRunIterationStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
inputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
class NodeRunIterationNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = None
pre_iteration_output: Any = None
class NodeRunIterationSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
class NodeRunIterationFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -10,31 +10,31 @@ from .base import GraphNodeEventBase
class NodeRunLoopStartedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
inputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
class NodeRunLoopNextEvent(GraphNodeEventBase):
node_title: str
index: int = Field(..., description="index")
pre_loop_output: Optional[Any] = None
pre_loop_output: Any = None
class NodeRunLoopSucceededEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
class NodeRunLoopFailedEvent(GraphNodeEventBase):
node_title: str
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Optional
from pydantic import Field
@ -12,9 +11,9 @@ from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
agent_strategy: Optional[AgentNodeStrategyInit] = None
predecessor_node_id: str | None = None
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
# FIXME(-LAN-): only for ToolNode

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -14,5 +14,5 @@ class AgentLogEvent(NodeEventBase):
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")
node_id: str = Field(..., description="node id")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -9,28 +9,28 @@ from .base import NodeEventBase
class IterationStartedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
inputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
class IterationNextEvent(NodeEventBase):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = None
pre_iteration_output: Any = None
class IterationSucceededEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
class IterationFailedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -9,28 +9,28 @@ from .base import NodeEventBase
class LoopStartedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
inputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
predecessor_node_id: str | None = None
class LoopNextEvent(NodeEventBase):
index: int = Field(..., description="index")
pre_loop_output: Optional[Any] = None
pre_loop_output: Any = None
class LoopSucceededEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
class LoopFailedEvent(NodeEventBase):
start_at: datetime = Field(..., description="start at")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
inputs: Mapping[str, Any] | None = None
outputs: Mapping[str, Any] | None = None
metadata: Mapping[str, Any] | None = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
@ -77,7 +77,7 @@ class AgentNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -86,7 +86,7 @@ class AgentNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -414,7 +414,7 @@ class AgentNode(Node):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]

View File

@ -1,6 +1,3 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
@ -12,7 +9,7 @@ class AgentNodeError(Exception):
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
def __init__(self, message: str, strategy_name: str | None = None, provider_name: str | None = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
@ -21,7 +18,7 @@ class AgentStrategyError(AgentNodeError):
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
def __init__(self, strategy_name: str, provider_name: str | None = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
@ -33,7 +30,7 @@ class AgentStrategyNotFoundError(AgentStrategyError):
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -41,7 +38,7 @@ class AgentInvocationError(AgentNodeError):
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
def __init__(self, message: str, parameter_name: str | None = None):
self.parameter_name = parameter_name
super().__init__(message)
@ -49,7 +46,7 @@ class AgentParameterError(AgentNodeError):
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
def __init__(self, message: str, variable_name: str | None = None):
self.variable_name = variable_name
super().__init__(message)
@ -71,7 +68,7 @@ class AgentInputTypeError(AgentNodeError):
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
def __init__(self, message: str, file_id: str | None = None):
self.file_id = file_id
super().__init__(message)
@ -86,7 +83,7 @@ class ToolFileNotFoundError(ToolFileError):
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
def __init__(self, message: str, original_error: Exception | None = None):
self.original_error = original_error
super().__init__(message)
@ -94,7 +91,7 @@ class AgentMessageTransformError(AgentNodeError):
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
def __init__(self, message: str, model_name: str | None = None, provider: str | None = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
@ -103,7 +100,7 @@ class AgentModelError(AgentNodeError):
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
def __init__(self, message: str, conversation_id: str | None = None):
self.conversation_id = conversation_id
super().__init__(message)
@ -114,9 +111,9 @@ class AgentVariableTypeError(AgentNodeError):
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
variable_name: str | None = None,
expected_type: str | None = None,
actual_type: str | None = None,
):
self.variable_name = variable_name
self.expected_type = expected_type

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
@ -20,7 +20,7 @@ class AnswerNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -29,7 +29,7 @@ class AnswerNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -2,7 +2,7 @@ import json
from abc import ABC
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Optional, Union
from typing import Any, Union
from pydantic import BaseModel, model_validator
@ -128,10 +128,10 @@ class DefaultValue(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
desc: str | None = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
@ -142,7 +142,7 @@ class BaseNodeData(ABC, BaseModel):
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseIterationState(BaseModel):
@ -157,7 +157,7 @@ class BaseIterationState(BaseModel):
class BaseLoopNodeData(BaseNodeData):
start_node_id: Optional[str] = None
start_node_id: str | None = None
class BaseLoopState(BaseModel):

View File

@ -145,11 +145,10 @@ class Node:
for event in result:
# NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
event = self._dispatch(event)
if not event.in_iteration_id and not event.in_loop_id:
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
yield event
yield event
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Optional
from typing import Any, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -30,7 +30,7 @@ class CodeNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -39,7 +39,7 @@ class CodeNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -49,7 +49,7 @@ class CodeNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.
@ -57,7 +57,7 @@ class CodeNode(Node):
"""
code_language = CodeLanguage.PYTHON3
if filters:
code_language = filters.get("code_language", CodeLanguage.PYTHON3)
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
@ -154,7 +154,7 @@ class CodeNode(Node):
def _transform_result(
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
output_schema: dict[str, CodeNodeData.Output] | None,
prefix: str = "",
depth: int = 1,
):

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional
from typing import Annotated, Literal, Self
from pydantic import AfterValidator, BaseModel
@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: Optional[dict[str, "CodeNodeData.Output"]] = None
children: dict[str, Self] | None = None
class Dependency(BaseModel):
name: str
@ -44,4 +44,4 @@ class CodeNodeData(BaseNodeData):
code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]
code: str
outputs: dict[str, Output]
dependencies: Optional[list[Dependency]] = None
dependencies: list[Dependency] | None = None

View File

@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
import chardet
import docx
@ -49,7 +49,7 @@ class DocumentExtractorNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -58,7 +58,7 @@ class DocumentExtractorNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -18,7 +18,7 @@ class EndNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class EndNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,7 +1,7 @@
import mimetypes
from collections.abc import Sequence
from email.message import Message
from typing import Any, Literal, Optional
from typing import Any, Literal
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -18,7 +18,7 @@ class HttpRequestNodeAuthorizationConfig(BaseModel):
class HttpRequestNodeAuthorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[HttpRequestNodeAuthorizationConfig] = None
config: HttpRequestNodeAuthorizationConfig | None = None
@field_validator("config", mode="before")
@classmethod
@ -88,9 +88,9 @@ class HttpRequestNodeData(BaseNodeData):
authorization: HttpRequestNodeAuthorization
headers: str
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
body: HttpRequestNodeBody | None = None
timeout: HttpRequestNodeTimeout | None = None
ssl_verify: bool | None = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:
@ -183,7 +183,7 @@ class Response:
return f"{(self.size / 1024 / 1024):.2f} MB"
@property
def parsed_content_disposition(self) -> Optional[Message]:
def parsed_content_disposition(self) -> Message | None:
content_disposition = self.headers.get("content-disposition", "")
if content_disposition:
msg = Message()

View File

@ -1,7 +1,7 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.file import File, FileTransferMethod
@ -39,7 +39,7 @@ class HttpRequestNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -48,7 +48,7 @@ class HttpRequestNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -58,7 +58,7 @@ class HttpRequestNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "http-request",
"config": {

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -20,7 +20,7 @@ class IfElseNodeData(BaseNodeData):
logical_operator: Literal["and", "or"]
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None
cases: list[Case] | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from typing_extensions import deprecated
@ -22,7 +22,7 @@ class IfElseNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -31,7 +31,7 @@ class IfElseNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import Field
@ -17,7 +17,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
parent_loop_id: Optional[str] = None # redundant field, not used currently
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
@ -39,7 +39,7 @@ class IterationState(BaseIterationState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any = None
class MetaData(BaseIterationState.MetaData):
"""
@ -48,7 +48,7 @@ class IterationState(BaseIterationState):
iterator_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -56,7 +56,7 @@ class IterationState(BaseIterationState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -82,7 +82,7 @@ class IterationNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: dict[str, object] | None = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "iteration",
"config": {

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -20,7 +20,7 @@ class IterationStartNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -29,7 +29,7 @@ class IterationStartNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, Field
@ -49,11 +49,11 @@ class MultipleRetrievalConfig(BaseModel):
"""
top_k: int
score_threshold: Optional[float] = None
score_threshold: float | None = None
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class SingleRetrievalConfig(BaseModel):
@ -104,8 +104,8 @@ class MetadataFilteringCondition(BaseModel):
Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class KnowledgeRetrievalNodeData(BaseNodeData):
@ -117,11 +117,11 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
query_variable_selector: list[str]
dataset_ids: list[str]
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
multiple_retrieval_config: MultipleRetrievalConfig | None = None
single_retrieval_config: SingleRetrievalConfig | None = None
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
@ -119,7 +119,7 @@ class KnowledgeRetrievalNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -128,7 +128,7 @@ class KnowledgeRetrievalNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -410,7 +410,7 @@ class KnowledgeRetrievalNode(Node):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
@ -568,7 +568,7 @@ class KnowledgeRetrievalNode(Node):
return automatic_metadata_filters
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list[Any]
self, sequence: int, condition: str, metadata_name: str, value: Any, filters: list[Any]
) -> list[Any]:
if value is None and condition not in ("empty", "not empty"):
return filters

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Optional, TypeAlias, TypeVar
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -43,7 +43,7 @@ class ListOperatorNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -52,7 +52,7 @@ class ListOperatorNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -18,7 +18,7 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
enabled: bool
variable_selector: Optional[list[str]] = None
variable_selector: list[str] | None = None
class VisionConfigOptions(BaseModel):
@ -51,18 +51,18 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: Optional[str] = None
jinja2_text: str | None = None
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Optional, cast
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
@ -86,8 +86,8 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None

View File

@ -4,7 +4,7 @@ import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -139,7 +139,7 @@ class LLMNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -148,7 +148,7 @@ class LLMNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -354,10 +354,10 @@ class LLMNode(Node):
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
@ -716,7 +716,7 @@ class LLMNode(Node):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
@ -959,7 +959,7 @@ class LLMNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "llm",
"config": {
@ -987,7 +987,7 @@ class LLMNode(Node):
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
@ -1175,7 +1175,7 @@ class LLMNode(Node):
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
@ -1281,7 +1281,7 @@ def _handle_memory_completion_mode(
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
@ -41,7 +41,7 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@ -74,7 +74,7 @@ class LoopState(BaseLoopState):
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Optional[Any] = None
current_output: Any = None
class MetaData(BaseLoopState.MetaData):
"""
@ -83,7 +83,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Optional[Any]:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@ -91,7 +91,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Optional[Any]:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -20,7 +20,7 @@ class LoopEndNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -29,7 +29,7 @@ class LoopEndNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from core.variables import Segment, SegmentType
from core.workflow.enums import (
@ -51,7 +51,7 @@ class LoopNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -60,7 +60,7 @@ class LoopNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
@ -20,7 +20,7 @@ class LoopStartNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -29,7 +29,7 @@ class LoopStartNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal, Optional
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
@ -48,7 +48,7 @@ class ParameterConfig(BaseModel):
name: str
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
options: list[str] | None = None
description: str
required: bool
@ -86,8 +86,8 @@ class ParameterExtractorNodeData(BaseNodeData):
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -3,7 +3,7 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
@ -96,7 +96,7 @@ class ParameterExtractorNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -105,7 +105,7 @@ class ParameterExtractorNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -114,11 +114,11 @@ class ParameterExtractorNode(Node):
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"model": {
"prompt_templates": {
@ -293,7 +293,7 @@ class ParameterExtractorNode(Node):
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
@ -323,9 +323,9 @@ class ParameterExtractorNode(Node):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@ -405,9 +405,9 @@ class ParameterExtractorNode(Node):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@ -443,9 +443,9 @@ class ParameterExtractorNode(Node):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate completion prompt.
@ -477,9 +477,9 @@ class ParameterExtractorNode(Node):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: Optional[ImagePromptMessageContent.DETAIL] = None,
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
Generate chat prompt.
@ -651,7 +651,7 @@ class ParameterExtractorNode(Node):
return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
@ -666,7 +666,7 @@ class ParameterExtractorNode(Node):
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
@ -705,7 +705,7 @@ class ParameterExtractorNode(Node):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -732,7 +732,7 @@ class ParameterExtractorNode(Node):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -768,7 +768,7 @@ class ParameterExtractorNode(Node):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@ -16,8 +14,8 @@ class QuestionClassifierNodeData(BaseNodeData):
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
instruction: str | None = None
memory: MemoryConfig | None = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -80,7 +80,7 @@ class QuestionClassifierNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -89,7 +89,7 @@ class QuestionClassifierNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters (not used in this implementation).
@ -285,7 +285,7 @@ class QuestionClassifierNode(Node):
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
context: Optional[str],
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
@ -328,7 +328,7 @@ class QuestionClassifierNode(Node):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: Optional[TokenBufferMemory],
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
@ -18,7 +18,7 @@ class StartNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -27,7 +27,7 @@ class StartNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
@ -20,7 +20,7 @@ class TemplateTransformNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -29,7 +29,7 @@ class TemplateTransformNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -470,7 +470,7 @@ class ToolNode(Node):
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -479,7 +479,7 @@ class ToolNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
from core.variables.types import SegmentType
@ -33,4 +31,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = "variable-assigner"
output_type: str
variables: list[list[str]]
advanced_settings: Optional[AdvancedSettings] = None
advanced_settings: AdvancedSettings | None = None

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.variables.segments import Segment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
@ -17,7 +17,7 @@ class VariableAggregatorNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -26,7 +26,7 @@ class VariableAggregatorNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias
from core.variables import SegmentType, Variable
from core.variables.segments import BooleanSegment
@ -33,7 +33,7 @@ class VariableAssignerNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -42,7 +42,7 @@ class VariableAssignerNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -18,7 +18,7 @@ class VariableOperationItem(BaseModel):
# 2. For VARIABLE input_type: Initially contains the selector of the source variable.
# 3. During the variable updating procedure: The `value` field is reassigned to hold
# the resolved actual value that will be applied to the target variable.
value: Any | None = None
value: Any = None
class VariableAssignerNodeData(BaseNodeData):

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@ -60,7 +60,7 @@ class VariableAssignerNode(Node):
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -69,7 +69,7 @@ class VariableAssignerNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from typing import Literal, Protocol
from core.workflow.entities import WorkflowNodeExecution
@ -10,7 +10,7 @@ class OrderConfig:
"""Configuration for ordering NodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
order_direction: Literal["asc", "desc"] | None = None
class WorkflowNodeExecutionRepository(Protocol):
@ -42,7 +42,7 @@ class WorkflowNodeExecutionRepository(Protocol):
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, Union
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@ -85,9 +85,9 @@ class WorkflowCycleManager:
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
external_trace_id: Optional[str] = None,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
external_trace_id: str | None = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
@ -112,9 +112,9 @@ class WorkflowCycleManager:
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
external_trace_id: Optional[str] = None,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
external_trace_id: str | None = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
@ -140,10 +140,10 @@ class WorkflowCycleManager:
total_steps: int,
status: WorkflowExecutionStatus,
error_message: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
exceptions_count: int = 0,
external_trace_id: Optional[str] = None,
external_trace_id: str | None = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
@ -295,9 +295,9 @@ class WorkflowCycleManager:
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: Optional[str] = None,
error_message: str | None = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
finished_at: datetime | None = None,
):
"""Update workflow execution with completion data."""
execution.status = status
@ -311,10 +311,10 @@ class WorkflowCycleManager:
def _add_trace_task_if_needed(
self,
trace_manager: Optional[TraceQueueManager],
trace_manager: TraceQueueManager | None,
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
external_trace_id: Optional[str],
conversation_id: str | None,
external_trace_id: str | None,
):
"""Add trace task if trace manager is provided."""
if trace_manager:
@ -401,7 +401,7 @@ class WorkflowCycleManager:
QueueNodeExceptionEvent,
],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
error: str | None = None,
handle_special_values: bool = False,
):
"""Update node execution with completion data."""

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@ -43,7 +43,7 @@ class WorkflowEntry:
call_depth: int,
variable_pool: VariablePool,
graph_runtime_state: GraphRuntimeState,
command_channel: Optional[CommandChannel] = None,
command_channel: CommandChannel | None = None,
) -> None:
"""
Init workflow entry
@ -341,7 +341,7 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain

View File

@ -543,12 +543,12 @@ class WorkflowService:
# This will prevent validation errors from breaking the workflow
return []
def get_default_block_configs(self) -> list[dict]:
def get_default_block_configs(self) -> Sequence[Mapping[str, object]]:
"""
Get default block configs
"""
# return default block config
default_block_configs = []
default_block_configs: list[Mapping[str, object]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()