chore: use from __future__ import annotations (#30254)

Co-authored-by: Dev <dev@Devs-MacBook-Pro-4.local>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Sara Rasool 2026-01-06 19:57:20 +05:00 committed by GitHub
parent 0294555893
commit 4f0fb6df2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 253 additions and 163 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
@ -99,7 +101,7 @@ class AccountPasswordPayload(BaseModel):
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
def check_passwords_match(self) -> AccountPasswordPayload:
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from configs import dify_config
@ -30,7 +32,7 @@ class DatasourcePlugin(ABC):
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> DatasourcePlugin:
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from enum import StrEnum
from typing import Any
@ -31,7 +33,7 @@ class DatasourceProviderType(enum.StrEnum):
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
def value_of(cls, value: str) -> DatasourceProviderType:
"""
Get value of given mode.
@ -81,7 +83,7 @@ class DatasourceParameter(PluginParameter):
typ: DatasourceParameterType,
required: bool,
options: list[str] | None = None,
) -> "DatasourceParameter":
) -> DatasourceParameter:
"""
get a simple datasource parameter
@ -187,14 +189,14 @@ class DatasourceInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
def empty(cls) -> DatasourceInvokeMeta:
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
def error_instance(cls, error: str) -> DatasourceInvokeMeta:
"""
Get an instance of DatasourceInvokeMeta with error
"""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from datetime import datetime
from enum import StrEnum
@ -75,7 +77,7 @@ class MCPProviderEntity(BaseModel):
updated_at: datetime
@classmethod
def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
def from_db_model(cls, db_provider: MCPToolProvider) -> MCPProviderEntity:
"""Create entity from database model with decryption"""
return cls(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
@ -178,7 +180,7 @@ class BasicProviderConfig(BaseModel):
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
def value_of(cls, value: str) -> ProviderConfig.Type:
"""
Get value of given mode.

View File

@ -68,13 +68,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import StrEnum, auto
@ -17,7 +19,7 @@ class PromptMessageRole(StrEnum):
TOOL = auto()
@classmethod
def value_of(cls, value: str) -> "PromptMessageRole":
def value_of(cls, value: str) -> PromptMessageRole:
"""
Get value of given mode.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from decimal import Decimal
from enum import StrEnum, auto
from typing import Any
@ -20,7 +22,7 @@ class ModelType(StrEnum):
TTS = auto()
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
def value_of(cls, origin_model_type: str) -> ModelType:
"""
Get model type from origin model type.
@ -103,7 +105,7 @@ class DefaultParameterName(StrEnum):
JSON_SCHEMA = auto()
@classmethod
def value_of(cls, value: Any) -> "DefaultParameterName":
def value_of(cls, value: Any) -> DefaultParameterName:
"""
Get parameter name from value.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import hashlib
import logging
from collections.abc import Sequence
@ -38,7 +40,7 @@ class ModelProviderFactory:
plugin_providers = self.get_plugin_model_providers()
return [provider.declaration for provider in plugin_providers]
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
"""
Get all plugin model providers
:return: list of plugin model providers
@ -76,7 +78,7 @@ class ModelProviderFactory:
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
return plugin_model_provider_entity.declaration
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
"""
Get plugin model provider
:param provider: provider name

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
@ -242,7 +244,7 @@ class CredentialType(enum.StrEnum):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
def of(cls, credential_type: str) -> CredentialType:
type_name = credential_type.lower()
if type_name in {"api-key", "api_key"}:
return cls.API_KEY

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib
import json
import logging
@ -6,7 +8,7 @@ import re
import threading
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
@ -76,7 +78,7 @@ class ClickzettaConnectionPool:
Manages connection reuse across ClickzettaVector instances.
"""
_instance: Optional["ClickzettaConnectionPool"] = None
_instance: ClickzettaConnectionPool | None = None
_lock = threading.Lock()
def __init__(self):
@ -89,7 +91,7 @@ class ClickzettaConnectionPool:
self._start_cleanup_thread()
@classmethod
def get_instance(cls) -> "ClickzettaConnectionPool":
def get_instance(cls) -> ClickzettaConnectionPool:
"""Get singleton instance of connection pool."""
if cls._instance is None:
with cls._lock:
@ -104,7 +106,7 @@ class ClickzettaConnectionPool:
f"{config.workspace}:{config.vcluster}:{config.schema_name}"
)
def _create_connection(self, config: ClickzettaConfig) -> "Connection":
def _create_connection(self, config: ClickzettaConfig) -> Connection:
"""Create a new ClickZetta connection."""
max_retries = 3
retry_delay = 1.0
@ -134,7 +136,7 @@ class ClickzettaConnectionPool:
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts")
def _configure_connection(self, connection: "Connection"):
def _configure_connection(self, connection: Connection):
"""Configure connection session settings."""
try:
with connection.cursor() as cursor:
@ -181,7 +183,7 @@ class ClickzettaConnectionPool:
except Exception:
logger.exception("Failed to configure connection, continuing with defaults")
def _is_connection_valid(self, connection: "Connection") -> bool:
def _is_connection_valid(self, connection: Connection) -> bool:
"""Check if connection is still valid."""
try:
with connection.cursor() as cursor:
@ -190,7 +192,7 @@ class ClickzettaConnectionPool:
except Exception:
return False
def get_connection(self, config: ClickzettaConfig) -> "Connection":
def get_connection(self, config: ClickzettaConfig) -> Connection:
"""Get a connection from the pool or create a new one."""
config_key = self._get_config_key(config)
@ -221,7 +223,7 @@ class ClickzettaConnectionPool:
# No valid connection found, create new one
return self._create_connection(config)
def return_connection(self, config: ClickzettaConfig, connection: "Connection"):
def return_connection(self, config: ClickzettaConfig, connection: Connection):
"""Return a connection to the pool."""
config_key = self._get_config_key(config)
@ -315,22 +317,22 @@ class ClickzettaVector(BaseVector):
self._connection_pool = ClickzettaConnectionPool.get_instance()
self._init_write_queue()
def _get_connection(self) -> "Connection":
def _get_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._connection_pool.get_connection(self._config)
def _return_connection(self, connection: "Connection"):
def _return_connection(self, connection: Connection):
"""Return a connection to the pool."""
self._connection_pool.return_connection(self._config, connection)
class ConnectionContext:
"""Context manager for borrowing and returning connections."""
def __init__(self, vector_instance: "ClickzettaVector"):
def __init__(self, vector_instance: ClickzettaVector):
self.vector = vector_instance
self.connection: Connection | None = None
def __enter__(self) -> "Connection":
def __enter__(self) -> Connection:
self.connection = self.vector._get_connection()
return self.connection
@ -338,7 +340,7 @@ class ClickzettaVector(BaseVector):
if self.connection:
self.vector._return_connection(self.connection)
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext":
def get_connection_context(self) -> ClickzettaVector.ConnectionContext:
"""Get a connection context manager."""
return self.ConnectionContext(self)
@ -437,7 +439,7 @@ class ClickzettaVector(BaseVector):
"""Return the vector database type."""
return "clickzetta"
def _ensure_connection(self) -> "Connection":
def _ensure_connection(self) -> Connection:
"""Get a connection from the pool."""
return self._get_connection()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any
@ -22,7 +24,7 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> DatasetDocumentStore:
return cls(**config_dict)
def to_dict(self) -> dict[str, Any]:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any
@ -16,7 +18,7 @@ class TaskWrapper(BaseModel):
return self.model_dump_json()
@classmethod
def deserialize(cls, serialized_data: str) -> "TaskWrapper":
def deserialize(cls, serialized_data: str) -> TaskWrapper:
return cls.model_validate_json(serialized_data)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json
import logging
import threading
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
class SchemaRegistry:
@ -11,7 +13,7 @@ class SchemaRegistry:
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
_default_instance: ClassVar[SchemaRegistry | None] = None
_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self, base_dir: str):
@ -20,7 +22,7 @@ class SchemaRegistry:
self.metadata: MutableMapping[str, MutableMapping[str, Any]] = {}
@classmethod
def default_registry(cls) -> "SchemaRegistry":
def default_registry(cls) -> SchemaRegistry:
"""Returns the default schema registry for builtin schemas (thread-safe singleton)"""
if cls._default_instance is None:
with cls._lock:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
@ -24,7 +26,7 @@ class Tool(ABC):
self.entity = entity
self.runtime = runtime
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> Tool:
"""
fork a new tool with metadata
:return: the new tool
@ -166,7 +168,7 @@ class Tool(ABC):
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
)
def create_file_message(self, file: "File") -> ToolInvokeMessage:
def create_file_message(self, file: File) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE,
message=ToolInvokeMessage.FileMessage(),

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
@ -24,7 +26,7 @@ class BuiltinTool(Tool):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> BuiltinTool:
"""
fork a new tool with metadata
:return: the new tool

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pydantic import Field
from sqlalchemy import select
@ -32,7 +34,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []
@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> ApiToolProviderController:
credentials_schema = [
ProviderConfig(
name="auth_type",

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import contextlib
from collections.abc import Mapping
@ -55,7 +57,7 @@ class ToolProviderType(StrEnum):
MCP = auto()
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
def value_of(cls, value: str) -> ToolProviderType:
"""
Get value of given mode.
@ -79,7 +81,7 @@ class ApiProviderSchemaType(StrEnum):
OPENAI_ACTIONS = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderSchemaType":
def value_of(cls, value: str) -> ApiProviderSchemaType:
"""
Get value of given mode.
@ -102,7 +104,7 @@ class ApiProviderAuthType(StrEnum):
API_KEY_QUERY = auto()
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
def value_of(cls, value: str) -> ApiProviderAuthType:
"""
Get value of given mode.
@ -307,7 +309,7 @@ class ToolParameter(PluginParameter):
typ: ToolParameterType,
required: bool,
options: list[str] | None = None,
) -> "ToolParameter":
) -> ToolParameter:
"""
get a simple tool parameter
@ -429,14 +431,14 @@ class ToolInvokeMeta(BaseModel):
tool_config: dict | None = None
@classmethod
def empty(cls) -> "ToolInvokeMeta":
def empty(cls) -> ToolInvokeMeta:
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "ToolInvokeMeta":
def error_instance(cls, error: str) -> ToolInvokeMeta:
"""
Get an instance of ToolInvokeMeta with error
"""

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import json
import logging
@ -118,7 +120,7 @@ class MCPTool(Tool):
for item in json_list:
yield self.create_json_message(item)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
return MCPTool(
entity=self.entity,
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Generator
from typing import Any
@ -46,7 +48,7 @@ class PluginTool(Tool):
message_id=message_id,
)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> PluginTool:
return PluginTool(
entity=self.entity,
runtime=runtime,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping
from pydantic import Field
@ -47,7 +49,7 @@ class WorkflowToolProviderController(ToolProviderController):
self.provider_id = provider_id
@classmethod
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
with session_factory.create_session() as session, session.begin():
app = session.get(App, db_provider.app_id)
if not app:

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -181,7 +183,7 @@ class WorkflowTool(Tool):
return found
return None
def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
def fork_tool_runtime(self, runtime: ToolRuntime) -> WorkflowTool:
"""
fork a new tool with metadata

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from core.file.models import File
@ -52,7 +54,7 @@ class SegmentType(StrEnum):
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
def infer_segment_type(cls, value: Any) -> SegmentType | None:
"""
Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
@ -173,7 +175,7 @@ class SegmentType(StrEnum):
raise AssertionError("this statement should be unreachable.")
@staticmethod
def cast_value(value: Any, type_: "SegmentType"):
def cast_value(value: Any, type_: SegmentType):
# Cast Python's `bool` type to `int` when the runtime type requires
# an integer or number.
#
@ -193,7 +195,7 @@ class SegmentType(StrEnum):
return [int(i) for i in value]
return value
def exposed_type(self) -> "SegmentType":
def exposed_type(self) -> SegmentType:
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
@ -202,7 +204,7 @@ class SegmentType(StrEnum):
return SegmentType.NUMBER
return self
def element_type(self) -> "SegmentType | None":
def element_type(self) -> SegmentType | None:
"""Return the element type of the current segment type, or `None` if the element type is undefined.
Raises:
@ -217,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: "SegmentType"):
def get_zero_value(t: SegmentType):
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@ -5,6 +5,8 @@ Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from typing import Any
@ -59,7 +61,7 @@ class WorkflowExecution(BaseModel):
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
) -> WorkflowExecution:
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -175,7 +177,7 @@ class Graph:
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
@ -197,7 +199,7 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@ -284,9 +286,9 @@ class Graph:
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
node_factory: NodeFactory,
root_node_id: str | None = None,
) -> "Graph":
) -> Graph:
"""
Initialize graph
@ -383,7 +385,7 @@ class GraphBuilder:
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
def add_root(self, node: Node) -> GraphBuilder:
"""Register the root node. Must be called exactly once."""
if self._nodes:
@ -398,7 +400,7 @@ class GraphBuilder:
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
) -> GraphBuilder:
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
@ -419,7 +421,7 @@ class GraphBuilder:
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder:
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:

View File

@ -5,6 +5,8 @@ This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
from __future__ import annotations
import contextvars
import logging
import queue
@ -232,7 +234,7 @@ class GraphEngine:
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
def layer(self, layer: GraphEngineLayer) -> GraphEngine:
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)

View File

@ -2,6 +2,8 @@
Factory for creating ReadyQueue instances from serialized state.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
@ -11,7 +13,7 @@ if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue:
"""
Create a ReadyQueue instance from a serialized state.

View File

@ -5,6 +5,8 @@ This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
@ -27,7 +29,7 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from an AnswerNode or EndNode.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
@ -167,7 +169,7 @@ class AgentNode(Node[AgentNodeData]):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: "PluginAgentStrategy",
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@ -328,7 +330,7 @@ class AgentNode(Node[AgentNodeData]):
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> "InvokeCredentials":
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
@ -442,9 +444,7 @@ class AgentNode(Node[AgentNodeData]):
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(
self, strategy: "PluginAgentStrategy", tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import importlib
import logging
import operator
@ -59,7 +61,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@ -198,14 +200,14 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self.id = id
@ -241,7 +243,7 @@ class Node(Generic[NodeDataT]):
return
@property
def graph_init_params(self) -> "GraphInitParams":
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@property
@ -457,7 +459,7 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.

View File

@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
similar to SegmentGroup but focused on template representation without values.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
@ -58,7 +60,7 @@ class Template:
segments: list[TemplateSegmentUnion]
@classmethod
def from_answer_template(cls, template_str: str) -> "Template":
def from_answer_template(cls, template_str: str) -> Template:
"""Create a Template from an Answer node template string.
Example:
@ -107,7 +109,7 @@ class Template:
return cls(segments=segments)
@classmethod
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
"""Create a Template from an End node outputs configuration.
End nodes are treated as templates of concatenated variables with newlines.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import io
import json
@ -113,7 +115,7 @@ class LLMNode(Node[LLMNodeData]):
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
@ -121,8 +123,8 @@ class LLMNode(Node[LLMNodeData]):
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@ -361,7 +363,7 @@ class LLMNode(Node[LLMNodeData]):
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -415,7 +417,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
@ -525,7 +527,7 @@ class LLMNode(Node[LLMNodeData]):
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
def _image_file_to_markdown(file: File, /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@ -774,7 +776,7 @@ class LLMNode(Node[LLMNodeData]):
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
sys_files: Sequence[File],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@ -785,7 +787,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
context_files: list["File"] | None = None,
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@ -1137,7 +1139,7 @@ class LLMNode(Node[LLMNodeData]):
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
@ -1179,7 +1181,7 @@ class LLMNode(Node[LLMNodeData]):
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@ -1229,7 +1231,7 @@ class LLMNode(Node[LLMNodeData]):
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
@ -23,7 +25,7 @@ class DraftVariableSaverFactory(Protocol):
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> "DraftVariableSaver":
) -> DraftVariableSaver:
pass

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
@ -267,6 +269,6 @@ class VariablePool(BaseModel):
self.add(selector, value)
@classmethod
def empty(cls) -> "VariablePool":
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from types import MappingProxyType
from typing import Any
@ -70,7 +72,7 @@ class SystemVariable(BaseModel):
return data
@classmethod
def empty(cls) -> "SystemVariable":
def empty(cls) -> SystemVariable:
return cls()
def to_dict(self) -> dict[SystemVariableKey, Any]:
@ -114,7 +116,7 @@ class SystemVariable(BaseModel):
d[SystemVariableKey.TIMESTAMP] = self.timestamp
return d
def as_view(self) -> "SystemVariableReadOnlyView":
def as_view(self) -> SystemVariableReadOnlyView:
return SystemVariableReadOnlyView(self)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
import os
import threading
@ -33,7 +35,7 @@ class AliyunLogStore:
Ensures only one instance exists to prevent multiple PG connection pools.
"""
_instance: "AliyunLogStore | None" = None
_instance: AliyunLogStore | None = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
@ -66,7 +68,7 @@ class AliyunLogStore:
"\t",
]
def __new__(cls) -> "AliyunLogStore":
def __new__(cls) -> AliyunLogStore:
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)

View File

@ -5,6 +5,8 @@ automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""
from __future__ import annotations
import json
import logging
import operator
@ -48,7 +50,7 @@ class FileMetadata:
return data
@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
def from_dict(cls, data: dict) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])

View File

@ -2,6 +2,8 @@
Broadcast channel for Pub/Sub messaging.
"""
from __future__ import annotations
import types
from abc import abstractmethod
from collections.abc import Iterator
@ -129,6 +131,6 @@ class BroadcastChannel(Protocol):
"""
@abstractmethod
def topic(self, topic: str) -> "Topic":
def topic(self, topic: str) -> Topic:
"""topic returns a `Topic` instance for the given topic name."""
...

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@ -20,7 +22,7 @@ class BroadcastChannel:
):
self._client = redis_client
def topic(self, topic: str) -> "Topic":
def topic(self, topic: str) -> Topic:
return Topic(self._client, topic)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
@ -18,7 +20,7 @@ class ShardedRedisBroadcastChannel:
):
self._client = redis_client
def topic(self, topic: str) -> "ShardedTopic":
def topic(self, topic: str) -> ShardedTopic:
return ShardedTopic(self._client, topic)

View File

@ -6,6 +6,8 @@ in Dify. It follows Domain-Driven Design principles with proper type hints and
eliminates the need for repetitive language switching logic.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum, auto
from typing import Any, Protocol
@ -53,7 +55,7 @@ class EmailLanguage(StrEnum):
ZH_HANS = "zh-Hans"
@classmethod
def from_language_code(cls, language_code: str) -> "EmailLanguage":
def from_language_code(cls, language_code: str) -> EmailLanguage:
"""Convert a language code to EmailLanguage with fallback to English."""
if language_code == "zh-Hans":
return cls.ZH_HANS

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import re
import uuid
@ -5,7 +7,7 @@ from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
import sqlalchemy as sa
@ -54,7 +56,7 @@ class AppMode(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> "AppMode":
def value_of(cls, value: str) -> AppMode:
"""
Get value of given mode.
@ -121,19 +123,19 @@ class App(Base):
return ""
@property
def site(self) -> Optional["Site"]:
def site(self) -> Site | None:
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self) -> Optional["AppModelConfig"]:
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@property
def workflow(self) -> Optional["Workflow"]:
def workflow(self) -> Workflow | None:
if self.workflow_id:
from .workflow import Workflow
@ -288,7 +290,7 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> list["Tag"]:
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@ -1194,7 +1196,7 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> list["MessageAgentThought"]:
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
@ -1307,7 +1309,7 @@ class Message(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Message":
def from_dict(cls, data: dict[str, Any]) -> Message:
return cls(
id=data["id"],
app_id=data["app_id"],

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum, auto
from functools import cached_property
@ -19,7 +21,7 @@ class ProviderType(StrEnum):
SYSTEM = auto()
@staticmethod
def value_of(value: str) -> "ProviderType":
def value_of(value: str) -> ProviderType:
for member in ProviderType:
if member.value == value:
return member
@ -37,7 +39,7 @@ class ProviderQuotaType(StrEnum):
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> "ProviderQuotaType":
def value_of(value: str) -> ProviderQuotaType:
for member in ProviderQuotaType:
if member.value == value:
return member

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
from datetime import datetime
from decimal import Decimal
@ -167,11 +169,11 @@ class ApiToolProvider(TypeBase):
)
@property
def schema_type(self) -> "ApiProviderSchemaType":
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def tools(self) -> list["ApiToolBundle"]:
def tools(self) -> list[ApiToolBundle]:
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
@ -267,7 +269,7 @@ class WorkflowToolProvider(TypeBase):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
@ -359,7 +361,7 @@ class MCPToolProvider(TypeBase):
except (json.JSONDecodeError, TypeError):
return []
def to_entity(self) -> "MCPProviderEntity":
def to_entity(self) -> MCPProviderEntity:
"""Convert to domain entity"""
from core.entities.mcp_provider import MCPProviderEntity
@ -533,5 +535,5 @@ class DeprecatedPublishedAppTool(TypeBase):
)
@property
def description_i18n(self) -> "I18nObject":
def description_i18n(self) -> I18nObject:
return I18nObject.model_validate(json.loads(self.description))

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@ -67,7 +69,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> "WorkflowType":
def value_of(cls, value: str) -> WorkflowType:
"""
Get value of given mode.
@ -80,7 +82,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
"""
Get workflow type from app mode.
@ -181,7 +183,7 @@ class Workflow(Base): # bug
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
) -> Workflow:
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@ -619,7 +621,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
pause: Mapped[WorkflowPause | None] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@ -689,7 +691,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@ -841,7 +843,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@ -851,13 +853,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@ -932,7 +934,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@ -1046,7 +1048,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[Optional["UploadFile"]] = orm.relationship(
file: Mapped[UploadFile | None] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1064,7 +1066,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
"""
Get value of given mode.
@ -1181,7 +1183,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable":
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
obj = cls(
id=variable.id,
app_id=app_id,
@ -1334,7 +1336,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@ -1504,7 +1506,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> "WorkflowDraftVariable":
) -> WorkflowDraftVariable:
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@ -1527,7 +1529,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> "WorkflowDraftVariable":
) -> WorkflowDraftVariable:
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@ -1548,7 +1550,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> "WorkflowDraftVariable":
) -> WorkflowDraftVariable:
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@ -1571,7 +1573,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> "WorkflowDraftVariable":
) -> WorkflowDraftVariable:
variable = cls._new(
app_id=app_id,
node_id=node_id,
@ -1667,7 +1669,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped["UploadFile"] = orm.relationship(
upload_file: Mapped[UploadFile] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@ -1734,7 +1736,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
workflow_run: Mapped[WorkflowRun] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@ -1790,7 +1792,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Mapping
@ -106,7 +108,7 @@ class VariableTruncator(BaseTruncator):
self._max_size_bytes = max_size_bytes
@classmethod
def default(cls) -> "VariableTruncator":
def default(cls) -> VariableTruncator:
return VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import json
from dataclasses import dataclass
@ -78,7 +80,7 @@ class WebsiteCrawlApiRequest:
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
def from_args(cls, args: dict) -> WebsiteCrawlApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
@ -102,7 +104,7 @@ class WebsiteCrawlStatusApiRequest:
job_id: str
@classmethod
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest:
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:

View File

@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing
the behavior of mock nodes during testing.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@ -95,67 +97,67 @@ class MockConfigBuilder:
def __init__(self) -> None:
self._config = MockConfig()
def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder:
"""Enable or disable auto-mocking."""
self._config.enable_auto_mock = enabled
return self
def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
def with_delays(self, enabled: bool = True) -> MockConfigBuilder:
"""Enable or disable simulated execution delays."""
self._config.simulate_delays = enabled
return self
def with_llm_response(self, response: str) -> "MockConfigBuilder":
def with_llm_response(self, response: str) -> MockConfigBuilder:
"""Set default LLM response."""
self._config.default_llm_response = response
return self
def with_agent_response(self, response: str) -> "MockConfigBuilder":
def with_agent_response(self, response: str) -> MockConfigBuilder:
"""Set default agent response."""
self._config.default_agent_response = response
return self
def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default tool response."""
self._config.default_tool_response = response
return self
def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
def with_retrieval_response(self, response: str) -> MockConfigBuilder:
"""Set default retrieval response."""
self._config.default_retrieval_response = response
return self
def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default HTTP response."""
self._config.default_http_response = response
return self
def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
def with_template_transform_response(self, response: str) -> MockConfigBuilder:
"""Set default template transform response."""
self._config.default_template_transform_response = response
return self
def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default code execution response."""
self._config.default_code_response = response
return self
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder:
"""Set outputs for a specific node."""
self._config.set_node_outputs(node_id, outputs)
return self
def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder:
"""Set error for a specific node."""
self._config.set_node_error(node_id, error)
return self
def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder:
"""Add a node-specific configuration."""
self._config.set_node_config(config.node_id, config)
return self
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder:
"""Set default configuration for a node type."""
self._config.set_default_config(node_type, config)
return self

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import sys
import types
from collections.abc import Generator
@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only
@pytest.fixture
def tool_node(monkeypatch) -> "ToolNode":
def tool_node(monkeypatch) -> ToolNode:
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
return events, stop.value
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
def _identity_transform(messages, *_args, **_kwargs):
return messages
@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
return _collect_events(generator)
def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
tenant_id="tenant-id",
type=FileType.DOCUMENT,
@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
assert files_segment.value == [file_obj]
def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
def test_plain_link_messages_remain_links(tool_node: ToolNode):
message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),