diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index d636548f2b..a258144d35 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): system_variables=system_inputs, user_inputs=inputs, environment_variables=self._workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. conversation_variables=conversation_variables, ) @@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): trace_manager=app_generate_entity.trace_manager, ) - def _initialize_conversation_variables(self) -> list[VariableUnion]: + def _initialize_conversation_variables(self) -> list[Variable]: """ Initialize conversation variables for the current conversation. @@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation_variables = [var.to_variable() for var in existing_variables] session.commit() - return cast(list[VariableUnion], conversation_variables) + return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: """ diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index 77cc00bdc9..c070845b73 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,6 +1,6 @@ import logging -from core.variables import Variable +from core.variables import VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType @@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): if selector[0] != CONVERSATION_VARIABLE_NODE_ID: continue variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): logger.warning( "Conversation variable not found in variable pool. selector=%s", selector, diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 7a1cbf9940..7498224923 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -30,6 +30,7 @@ from .variables import ( SecretVariable, StringVariable, Variable, + VariableBase, ) __all__ = [ @@ -62,4 +63,5 @@ __all__ = [ "StringSegment", "StringVariable", "Variable", + "VariableBase", ] diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 406b4e6f93..8330f1fe19 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None: # - All variants in `SegmentUnion` must inherit from the `Segment` class. # - The union must include all non-abstract subclasses of `Segment`, except: # - `SegmentGroup`, which is not added to the variable pool. -# - `Variable` and its subclasses, which are handled by `VariableUnion`. +# - `VariableBase` and its subclasses, which are handled by `Variable`. SegmentUnion: TypeAlias = Annotated[ ( Annotated[NoneSegment, Tag(SegmentType.NONE)] diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 9fd0bbc5b2..a19c53918d 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -27,7 +27,7 @@ from .segments import ( from .types import SegmentType -class Variable(Segment): +class VariableBase(Segment): """ A variable is a segment that has a name. @@ -45,23 +45,23 @@ class Variable(Segment): selector: Sequence[str] = Field(default_factory=list) -class StringVariable(StringSegment, Variable): +class StringVariable(StringSegment, VariableBase): pass -class FloatVariable(FloatSegment, Variable): +class FloatVariable(FloatSegment, VariableBase): pass -class IntegerVariable(IntegerSegment, Variable): +class IntegerVariable(IntegerSegment, VariableBase): pass -class ObjectVariable(ObjectSegment, Variable): +class ObjectVariable(ObjectSegment, VariableBase): pass -class ArrayVariable(ArraySegment, Variable): +class ArrayVariable(ArraySegment, VariableBase): pass @@ -89,16 +89,16 @@ class SecretVariable(StringVariable): return encrypter.obfuscated_token(self.value) -class NoneVariable(NoneSegment, Variable): +class NoneVariable(NoneSegment, VariableBase): value_type: SegmentType = SegmentType.NONE value: None = None -class FileVariable(FileSegment, Variable): +class FileVariable(FileSegment, VariableBase): pass -class BooleanVariable(BooleanSegment, Variable): +class BooleanVariable(BooleanSegment, VariableBase): pass @@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel): value: Any -# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Variable` for type hinting when serialization is not required. +# The `Variable` type is used to enable serialization and deserialization with Pydantic. +# Use `VariableBase` for type hinting when serialization is not required. # # Note: -# - All variants in `VariableUnion` must inherit from the `Variable` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -VariableUnion: TypeAlias = Annotated[ +# - All variants in `Variable` must inherit from the `VariableBase` class. +# - The union must include all non-abstract subclasses of `VariableBase`. +Variable: TypeAlias = Annotated[ ( Annotated[NoneVariable, Tag(SegmentType.NONE)] | Annotated[StringVariable, Tag(SegmentType.STRING)] diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index fd78248c17..75f47691da 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.variables import Variable +from core.variables import VariableBase class ConversationVariableUpdater(Protocol): @@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol): """ @abc.abstractmethod - def update(self, conversation_id: str, variable: "Variable"): + def update(self, conversation_id: str, variable: "VariableBase"): """ Updates the value of the specified conversation variable in the underlying storage. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `Variable` instance containing the updated value. + :param variable: The `VariableBase` instance containing the updated value. """ pass diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 6dce03c94d..41276eb444 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.variables.variables import VariableUnion +from core.variables.variables import Variable class CommandType(StrEnum): @@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand): class VariableUpdate(BaseModel): """Represents a single variable update instruction.""" - value: VariableUnion = Field(description="New variable value") + value: Variable = Field(description="New variable value") class UpdateVariablesCommand(GraphEngineCommand): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e5d86414c1..91df2e4e0b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -11,7 +11,7 @@ from typing_extensions import TypeIs from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment -from core.variables.variables import VariableUnion +from core.variables.variables import Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( NodeExecutionType, @@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): datetime, list[GraphNodeEventBase], object | None, - dict[str, VariableUnion], + dict[str, Variable], LLMUsage, ] ], @@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): item: object, flask_app: Flask, context_vars: contextvars.Context, - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: + ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: + def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: + def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: parent_pool = self.graph_runtime_state.variable_pool parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index ac2870aa65..9f5818f4bb 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, Variable): + if not isinstance(original_variable, VariableBase): raise VariableOperatorNodeError("assigned variable not found") match self.node_data.write_mode: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 486e6bb6a7..5857702e72 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -2,7 +2,7 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, Variable +from core.variables import SegmentType, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): # ==================== Validation Part # Check if variable exists - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=item.variable_selector) # Check if operation is supported @@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): for selector in updated_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, Variable): + if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def _handle_item( self, *, - variable: Variable, + variable: VariableBase, operation: Operation, value: Any, ): diff --git a/api/core/workflow/runtime/variable_pool.py b/api/core/workflow/runtime/variable_pool.py index 85ceb9d59e..d205c6ac8f 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager -from core.variables import Segment, SegmentGroup, Variable +from core.variables import Segment, SegmentGroup, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import FileSegment, ObjectSegment -from core.variables.variables import RAGPipelineVariableInput, VariableUnion +from core.variables.variables import RAGPipelineVariableInput, Variable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -32,7 +32,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -46,13 +46,13 @@ class VariablePool(BaseModel): description="System variables", default_factory=SystemVariable.empty, ) - environment_variables: Sequence[VariableUnion] = Field( + environment_variables: Sequence[Variable] = Field( description="Environment variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) - conversation_variables: Sequence[VariableUnion] = Field( + conversation_variables: Sequence[Variable] = Field( description="Conversation variables.", - default_factory=list[VariableUnion], + default_factory=list[Variable], ) rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( description="RAG pipeline variables.", @@ -105,7 +105,7 @@ class VariablePool(BaseModel): f"got {len(selector)} elements" ) - if isinstance(value, Variable): + if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): variable = variable_factory.segment_to_variable(segment=value, selector=selector) @@ -114,9 +114,9 @@ class VariablePool(BaseModel): variable = variable_factory.segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + self.variable_dictionary[node_id][name] = cast(Variable, variable) @classmethod def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index ea0bdc3537..7992785fe1 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -2,7 +2,7 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.variables import Variable +from core.variables import VariableBase from core.variables.consts import SELECTORS_LENGTH from core.workflow.runtime import VariablePool @@ -26,7 +26,7 @@ class VariableLoader(Protocol): """ @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: """Load variables based on the provided selectors. If the selectors are empty, this method should return an empty list. @@ -36,7 +36,7 @@ class VariableLoader(Protocol): :param: selectors: a list of string list, each inner list should have at least two elements: - the first element is the node ID, - the second element is the variable name. - :return: a list of Variable objects that match the provided selectors. + :return: a list of VariableBase objects that match the provided selectors. """ pass @@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader): Serves as a placeholder when no variable loading is needed. """ - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: return [] diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 494194369a..3f030ae127 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -38,7 +38,7 @@ from core.variables.variables import ( ObjectVariable, SecretVariable, StringVariable, - Variable, + VariableBase, ) from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, @@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = { } -def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) -def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("name"): raise VariableError("missing name") return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) -def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: +def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: if not mapping.get("variable"): raise VariableError("missing variable") return mapping["variable"] -def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: +def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase: """ This factory function is used to create the environment variable or the conversation variable, not support the File type. @@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if (value := mapping.get("value")) is None: raise VariableError("missing value") - result: Variable + result: VariableBase match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return cast(Variable, result) + return cast(VariableBase, result) def build_segment(value: Any, /) -> Segment: @@ -285,8 +285,8 @@ def segment_to_variable( id: str | None = None, name: str | None = None, description: str = "", -) -> Variable: - if isinstance(segment, Variable): +) -> VariableBase: + if isinstance(segment, VariableBase): return segment name = name or selector[-1] id = id or str(uuid4()) @@ -297,7 +297,7 @@ def segment_to_variable( variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] return cast( - Variable, + VariableBase, variable_class( id=id, name=name, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index d037b0c442..2755f77f61 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.variables import SecretVariable, SegmentType, Variable +from core.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw): "value_type": value.value_type.value, "description": value.description, } - if isinstance(value, Variable): + if isinstance(value, VariableBase): return { "id": value.id, "name": value.name, diff --git a/api/models/workflow.py b/api/models/workflow.py index 072c6100b5..5d92da3fa1 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,11 +1,9 @@ -from __future__ import annotations - import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -46,7 +44,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Segment, SegmentType, Variable +from core.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -69,7 +67,7 @@ class WorkflowType(StrEnum): RAG_PIPELINE = "rag-pipeline" @classmethod - def value_of(cls, value: str) -> WorkflowType: + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -82,7 +80,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. @@ -178,12 +176,12 @@ class Workflow(Base): # bug graph: str, features: str, created_by: str, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list[dict], marked_name: str = "", marked_comment: str = "", - ) -> Workflow: + ) -> "Workflow": workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -447,7 +445,7 @@ class Workflow(Base): # bug # decrypt secret variables value def decrypt_func( - var: Variable, + var: VariableBase, ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) @@ -463,7 +461,7 @@ class Workflow(Base): # bug return decrypted_results @environment_variables.setter - def environment_variables(self, value: Sequence[Variable]): + def environment_variables(self, value: Sequence[VariableBase]): if not value: self._environment_variables = "{}" return @@ -487,7 +485,7 @@ class Workflow(Base): # bug value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value - def encrypt_func(var: Variable) -> Variable: + def encrypt_func(var: VariableBase) -> VariableBase: if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) else: @@ -517,7 +515,7 @@ class Workflow(Base): # bug return result @property - def conversation_variables(self) -> Sequence[Variable]: + def conversation_variables(self) -> Sequence[VariableBase]: # TODO: find some way to init `self._conversation_variables` when instance created. if self._conversation_variables is None: self._conversation_variables = "{}" @@ -527,7 +525,7 @@ class Workflow(Base): # bug return results @conversation_variables.setter - def conversation_variables(self, value: Sequence[Variable]): + def conversation_variables(self, value: Sequence[VariableBase]): self._conversation_variables = json.dumps( {var.name: var.model_dump() for var in value}, ensure_ascii=False, @@ -622,7 +620,7 @@ class WorkflowRun(Base): finished_at: Mapped[datetime | None] = mapped_column(DateTime) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) - pause: Mapped[WorkflowPause | None] = orm.relationship( + pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( "WorkflowPause", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", uselist=False, @@ -692,7 +690,7 @@ class WorkflowRun(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: + def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": return cls( id=data.get("id"), tenant_id=data.get("tenant_id"), @@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) - offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( + offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship( "WorkflowNodeExecutionOffload", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", uselist=True, @@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def preload_offload_data( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) @staticmethod def preload_offload_data_and_files( - query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], + query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"], ): return query.options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( @@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base): back_populates="offload_data", ) - file: Mapped[UploadFile | None] = orm.relationship( + file: Mapped[Optional["UploadFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum): INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. @@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase): ) @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: + def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, @@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase): ) return obj - def to_variable(self) -> Variable: + def to_variable(self) -> VariableBase: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) @@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base): ) # Relationship to WorkflowDraftVariableFile - variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( + variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( foreign_keys=[file_id], lazy="raise", uselist=False, @@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base): node_execution_id: str | None, description: str = "", file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = WorkflowDraftVariable() variable.id = str(uuid4()) variable.created_at = naive_utc_now() @@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base): name: str, value: Segment, description: str = "", - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, @@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base): value: Segment, node_execution_id: str, editable: bool = False, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, @@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base): visible: bool = True, editable: bool = True, file_id: str | None = None, - ) -> WorkflowDraftVariable: + ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=node_id, @@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base): ) # Relationship to UploadFile - upload_file: Mapped[UploadFile] = orm.relationship( + upload_file: Mapped["UploadFile"] = orm.relationship( foreign_keys=[upload_file_id], lazy="raise", uselist=False, @@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base): state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) # Relationship to WorkflowRun - workflow_run: Mapped[WorkflowRun] = orm.relationship( + workflow_run: Mapped["WorkflowRun"] = orm.relationship( foreign_keys=[workflow_run_id], # require explicit preloading. lazy="raise", @@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: + def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": if isinstance(pause_reason, HumanInputRequired): return cls( type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index acc0ec2b22..92008d5ff1 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.variables.variables import Variable +from core.variables.variables import VariableBase from models import ConversationVariable @@ -13,7 +13,7 @@ class ConversationVariableUpdater: def __init__(self, session_maker: sessionmaker[Session]) -> None: self._session_maker: sessionmaker[Session] = session_maker - def update(self, conversation_id: str, variable: Variable) -> None: + def update(self, conversation_id: str, variable: VariableBase) -> None: stmt = select(ConversationVariable).where( ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 1ba64813ba..2d8418900c 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,7 +36,7 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import Variable +from core.variables.variables import VariableBase from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -270,8 +270,8 @@ class RagPipelineService: graph: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], rag_pipeline_variables: list, ) -> Workflow: """ diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 9407a2b3f0..70b0190231 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File -from core.variables import Segment, StringSegment, Variable +from core.variables import Segment, StringSegment, VariableBase from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ( ArrayFileSegment, @@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader): # Application ID for which variables are being loaded. _app_id: str _tenant_id: str - _fallback_variables: Sequence[Variable] + _fallback_variables: Sequence[VariableBase] def __init__( self, engine: Engine, app_id: str, tenant_id: str, - fallback_variables: Sequence[Variable] | None = None, + fallback_variables: Sequence[VariableBase] | None = None, ): self._engine = engine self._app_id = app_id @@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader): def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: return (selector[0], selector[1]) - def load_variables(self, selectors: list[list[str]]) -> list[Variable]: + def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: if not selectors: return [] - # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. - variable_by_selector: dict[tuple[str, str], Variable] = {} + # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance. + variable_by_selector: dict[tuple[str, str], VariableBase] = {} with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) @@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader): return list(variable_by_selector.values()) - def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: + def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]: # This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable` # and must remain synchronized with it. # Ideally, these should be co-located for better maintainability. diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b45a167b73..d8c3159178 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory -from core.variables import Variable -from core.variables.variables import VariableUnion +from core.variables import VariableBase +from core.variables.variables import Variable from core.workflow.entities import WorkflowNodeExecution from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError @@ -198,8 +198,8 @@ class WorkflowService: features: dict, unique_hash: str | None, account: Account, - environment_variables: Sequence[Variable], - conversation_variables: Sequence[Variable], + environment_variables: Sequence[VariableBase], + conversation_variables: Sequence[VariableBase], ) -> Workflow: """ Sync draft workflow @@ -1044,7 +1044,7 @@ def _setup_variable_pool( workflow: Workflow, node_type: NodeType, conversation_id: str, - conversation_variables: list[Variable], + conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if node_type == NodeType.START or node_type.is_trigger_node: @@ -1070,9 +1070,9 @@ def _setup_variable_pool( system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - # Based on the definition of `VariableUnion`, - # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - conversation_variables=cast(list[VariableUnion], conversation_variables), # + # Based on the definition of `Variable`, + # `VariableBase` instances can be safely used as `Variable` since they are compatible. + conversation_variables=cast(list[Variable], conversation_variables), # ) return variable_pool diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index af4f96ba23..aa16c8af1c 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -35,7 +35,6 @@ from core.variables.variables import ( SecretVariable, StringVariable, Variable, - VariableUnion, ) from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable @@ -96,7 +95,7 @@ class _Segments(BaseModel): class _Variables(BaseModel): - variables: list[VariableUnion] + variables: list[Variable] def create_test_file( @@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad: # Create one instance of each variable type test_file = create_test_file() - all_variables: list[VariableUnion] = [ + all_variables: list[Variable] = [ NoneVariable(name="none_var"), StringVariable(value="test string", name="string_var"), IntegerVariable(value=42, name="int_var"), diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 925142892c..fb4b18b57a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,7 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) -from core.variables.variables import Variable +from core.variables.variables import VariableBase def test_frozen_variables(): @@ -76,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var: Variable = StringVariable(name="text", value="text") + var: VariableBase = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9733bf60eb..b8869dbf1d 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -24,7 +24,7 @@ from core.variables.variables import ( IntegerVariable, ObjectVariable, StringVariable, - VariableUnion, + Variable, ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.runtime import VariablePool @@ -160,7 +160,7 @@ class TestVariablePoolSerialization: ) # Create environment variables with all types including ArrayFileVariable - env_vars: list[VariableUnion] = [ + env_vars: list[Variable] = [ StringVariable( id="env_string_id", name="env_string", @@ -182,7 +182,7 @@ class TestVariablePoolSerialization: ] # Create conversation variables with complex data - conv_vars: list[VariableUnion] = [ + conv_vars: list[Variable] = [ StringVariable( id="conv_string_id", name="conv_string",