From 6a9e0b10055acbb3c64222d74ae09ec836d2071f Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 19 May 2025 22:59:56 +0800 Subject: [PATCH] feat(api): Introduce `WorkflowDraftVariable` Model (#19737) - Introduce `WorkflowDraftVariable` model and the corresponding migration. - Implement `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`. --- ...hemy_workflow_node_execution_repository.py | 12 +- api/core/variables/consts.py | 7 + api/core/variables/utils.py | 8 + ...e1f5dfb_add_workflowdraftvariable_model.py | 51 ++++ api/models/enums.py | 7 + api/models/types.py | 53 ++++- api/models/workflow.py | 218 +++++++++++++++++- .../unit_tests/models/test_types_enum_text.py | 187 +++++++++++++++ 8 files changed, 533 insertions(+), 10 deletions(-) create mode 100644 api/core/variables/consts.py create mode 100644 api/core/variables/utils.py create mode 100644 api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py create mode 100644 api/tests/unit_tests/models/test_types_enum_text.py diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index af7b261135..3bf775db13 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -4,13 +4,14 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. import json import logging -from collections.abc import Sequence -from typing import Optional, Union +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_execution_entities import ( NodeExecution, NodeExecutionStatus, @@ -122,7 +123,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) status=status, error=db_model.error, elapsed_time=db_model.elapsed_time, - metadata=metadata, + # FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11. + # However, this problem is not occurred in Python 3.12. + # + # A case of this error is: + # https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24 + metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata), created_at=db_model.created_at, finished_at=db_model.finished_at, ) diff --git a/api/core/variables/consts.py b/api/core/variables/consts.py new file mode 100644 index 0000000000..03b277d619 --- /dev/null +++ b/api/core/variables/consts.py @@ -0,0 +1,7 @@ +# The minimal selector length for valid variables. +# +# The first element of the selector is the node id, and the second element is the variable name. +# +# If the selector length is more than 2, the remaining parts are the keys / indexes paths used +# to extract part of the variable value. +MIN_SELECTORS_LENGTH = 2 diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py new file mode 100644 index 0000000000..e5d222af7d --- /dev/null +++ b/api/core/variables/utils.py @@ -0,0 +1,8 @@ +from collections.abc import Iterable, Sequence + + +def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: + selectors = [node_id, name] + if paths: + selectors.extend(paths) + return selectors diff --git a/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py new file mode 100644 index 0000000000..5bf394b21c --- /dev/null +++ b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py @@ -0,0 +1,51 @@ +"""add WorkflowDraftVariable model + +Revision ID: 2adcbe1f5dfb +Revises: d28f2004b072 +Create Date: 2025-05-15 15:31:03.128680 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "2adcbe1f5dfb" +down_revision = "d28f2004b072" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "workflow_draft_variables", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("last_edited_at", sa.DateTime(), nullable=True), + sa.Column("node_id", sa.String(length=255), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.String(length=255), nullable=False), + sa.Column("selector", sa.String(length=255), nullable=False), + sa.Column("value_type", sa.String(length=20), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("visible", sa.Boolean(), nullable=False), + sa.Column("editable", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")), + sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")), + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # Dropping `workflow_draft_variables` also drops any index associated with it. + op.drop_table("workflow_draft_variables") + + # ### end Alembic commands ### diff --git a/api/models/enums.py b/api/models/enums.py index 7d9f6068bb..4434c3fec8 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -14,3 +14,10 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + + +class DraftVariableType(StrEnum): + # node means that the correspond variable + NODE = "node" + SYS = "sys" + CONVERSATION = "conversation" diff --git a/api/models/types.py b/api/models/types.py index cb6773e70c..e5581c3ab0 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,4 +1,7 @@ -from sqlalchemy import CHAR, TypeDecorator +import enum +from typing import Generic, TypeVar + +from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy.dialects.postgresql import UUID @@ -24,3 +27,51 @@ class StringUUID(TypeDecorator): if value is None: return value return str(value) + + +_E = TypeVar("_E", bound=enum.StrEnum) + + +class EnumText(TypeDecorator, Generic[_E]): + impl = VARCHAR + cache_ok = True + + _length: int + _enum_class: type[_E] + + def __init__(self, enum_class: type[_E], length: int | None = None): + self._enum_class = enum_class + max_enum_value_len = max(len(e.value) for e in enum_class) + if length is not None: + if length < max_enum_value_len: + raise ValueError("length should be greater than enum value length.") + self._length = length + else: + # leave some rooms for future longer enum values. + self._length = max(max_enum_value_len, 20) + + def process_bind_param(self, value: _E | str | None, dialect): + if value is None: + return value + if isinstance(value, self._enum_class): + return value.value + elif isinstance(value, str): + self._enum_class(value) + return value + else: + raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(VARCHAR(self._length)) + + def process_result_value(self, value, dialect) -> _E | None: + if value is None: + return value + if not isinstance(value, str): + raise TypeError(f"expected str, got {type(value)}") + return self._enum_class(value) + + def compare_values(self, x, y): + if x is None or y is None: + return x is y + return x == y diff --git a/api/models/workflow.py b/api/models/workflow.py index fd0d279d50..a81c889277 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,29 +1,36 @@ import json +import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Self, Union from uuid import uuid4 +from core.variables import utils as variable_utils +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment + if TYPE_CHECKING: from models.model import AppMode import sqlalchemy as sa -from sqlalchemy import func +from sqlalchemy import UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Variable +from core.variables import SecretVariable, Segment, SegmentType, Variable from factories import variable_factory from libs import helper from .account import Account from .base import Base from .engine import db -from .enums import CreatorUserRole -from .types import StringUUID +from .enums import CreatorUserRole, DraftVariableType +from .types import EnumText, StringUUID + +_logger = logging.getLogger(__name__) if TYPE_CHECKING: from models.model import AppMode @@ -651,7 +658,7 @@ class WorkflowNodeExecution(Base): return json.loads(self.inputs) if self.inputs else None @property - def outputs_dict(self): + def outputs_dict(self) -> dict[str, Any] | None: return json.loads(self.outputs) if self.outputs else None @property @@ -659,7 +666,7 @@ class WorkflowNodeExecution(Base): return json.loads(self.process_data) if self.process_data else None @property - def execution_metadata_dict(self): + def execution_metadata_dict(self) -> dict[str, Any] | None: return json.loads(self.execution_metadata) if self.execution_metadata else None @property @@ -797,3 +804,202 @@ class ConversationVariable(Base): def to_variable(self) -> Variable: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) + + +# Only `sys.query` and `sys.files` could be modified. +_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) + + +def _naive_utc_datetime(): + return datetime.now(UTC).replace(tzinfo=None) + + +class WorkflowDraftVariable(Base): + @staticmethod + def unique_columns() -> list[str]: + return [ + "app_id", + "node_id", + "name", + ] + + __tablename__ = "workflow_draft_variables" + __table_args__ = (UniqueConstraint(*unique_columns()),) + + # id is the unique identifier of a draft variable. + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + + created_at = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + ) + + updated_at = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + # "`app_id` maps to the `id` field in the `model.App` model." + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # `last_edited_at` records when the value of a given draft variable + # is edited. + # + # If it's not edited after creation, its value is `None`. + last_edited_at: Mapped[datetime | None] = mapped_column( + db.DateTime, + nullable=True, + default=None, + ) + + # The `node_id` field is special. + # + # If the variable is a conversation variable or a system variable, then the value of `node_id` + # is `conversation` or `sys`, respective. + # + # Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is + # the identity of correspond node in graph definition. An example of node id is `"1745769620734"`. + # + # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other + # "Answer" node conform the rules above.) + node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") + + # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than + # 80 chars. + # + # ref: api/core/workflow/entities/variable_pool.py:18 + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column( + sa.String(255), + default="", + nullable=False, + ) + + selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") + + value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) + # JSON string + value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") + + # visible + visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) + + def get_selector(self) -> list[str]: + selector = json.loads(self.selector) + if not isinstance(selector, list): + _logger.error( + "invalid selector loaded from database, type=%s, value=%s", + type(selector), + self.selector, + ) + raise ValueError("invalid selector.") + return selector + + def _set_selector(self, value: list[str]): + self.selector = json.dumps(value) + + def get_value(self) -> Segment | None: + return build_segment(json.loads(self.value)) + + def set_name(self, name: str): + self.name = name + self._set_selector([self.node_id, name]) + + def set_value(self, value: Segment): + self.value = json.dumps(value.value) + self.value_type = value.value_type + + def get_node_id(self) -> str | None: + if self.get_variable_type() == DraftVariableType.NODE: + return self.node_id + else: + return None + + def get_variable_type(self) -> DraftVariableType: + match self.node_id: + case DraftVariableType.CONVERSATION: + return DraftVariableType.CONVERSATION + case DraftVariableType.SYS: + return DraftVariableType.SYS + case _: + return DraftVariableType.NODE + + @classmethod + def _new( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + description: str = "", + ) -> "WorkflowDraftVariable": + variable = WorkflowDraftVariable() + variable.created_at = _naive_utc_datetime() + variable.updated_at = _naive_utc_datetime() + variable.description = description + variable.app_id = app_id + variable.node_id = node_id + variable.name = name + variable.app_id = app_id + variable.set_value(value) + variable._set_selector(list(variable_utils.to_selector(node_id, name))) + return variable + + @classmethod + def new_conversation_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + ) -> "WorkflowDraftVariable": + variable = cls._new( + app_id=app_id, + node_id=CONVERSATION_VARIABLE_NODE_ID, + name=name, + value=value, + ) + return variable + + @classmethod + def new_sys_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + editable: bool = False, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable.editable = editable + return variable + + @classmethod + def new_node_variable( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + visible: bool = True, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable.visible = visible + variable.editable = True + return variable + + @property + def edited(self): + return self.last_edited_at is not None + + +def is_system_variable_editable(name: str) -> bool: + return name in _EDITABLE_SYSTEM_VARIABLE diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/unit_tests/models/test_types_enum_text.py new file mode 100644 index 0000000000..3afa0f17a0 --- /dev/null +++ b/api/tests/unit_tests/models/test_types_enum_text.py @@ -0,0 +1,187 @@ +from collections.abc import Callable, Iterable +from enum import StrEnum +from typing import Any, NamedTuple, TypeVar + +import pytest +import sqlalchemy as sa +from sqlalchemy import exc as sa_exc +from sqlalchemy import insert +from sqlalchemy.orm import DeclarativeBase, Mapped, Session +from sqlalchemy.sql.sqltypes import VARCHAR + +from models.types import EnumText + +_user_type_admin = "admin" +_user_type_normal = "normal" + + +class _Base(DeclarativeBase): + pass + + +class _UserType(StrEnum): + admin = _user_type_admin + normal = _user_type_normal + + +class _EnumWithLongValue(StrEnum): + unknown = "unknown" + a_really_long_enum_values = "a_really_long_enum_values" + + +class _User(_Base): + __tablename__ = "users" + + id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) + name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False) + user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) + user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True) + + +class _ColumnTest(_Base): + __tablename__ = "column_test" + + id: Mapped[int] = sa.Column(sa.Integer, primary_key=True) + + user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal) + explicit_length: Mapped[_UserType | None] = sa.Column( + EnumText(_UserType, length=50), nullable=True, default=_UserType.normal + ) + long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False) + + +_T = TypeVar("_T") + + +def _first(it: Iterable[_T]) -> _T: + ls = list(it) + if not ls: + raise ValueError("List is empty") + return ls[0] + + +class TestEnumText: + def test_column_impl(self): + engine = sa.create_engine("sqlite://", echo=False) + _Base.metadata.create_all(engine) + + inspector = sa.inspect(engine) + columns = inspector.get_columns(_ColumnTest.__tablename__) + + user_type_column = _first(c for c in columns if c["name"] == "user_type") + sql_type = user_type_column["type"] + assert isinstance(user_type_column["type"], VARCHAR) + assert sql_type.length == 20 + assert user_type_column["nullable"] is False + + explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length") + sql_type = explicit_length_column["type"] + assert isinstance(sql_type, VARCHAR) + assert sql_type.length == 50 + assert explicit_length_column["nullable"] is True + + long_value_column = _first(c for c in columns if c["name"] == "long_value") + sql_type = long_value_column["type"] + assert isinstance(sql_type, VARCHAR) + assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values) + + def test_insert_and_select(self): + engine = sa.create_engine("sqlite://", echo=False) + _Base.metadata.create_all(engine) + + with Session(engine) as session: + admin_user = _User( + name="admin", + user_type=_UserType.admin, + user_type_nullable=None, + ) + session.add(admin_user) + session.flush() + admin_user_id = admin_user.id + + normal_user = _User( + name="normal", + user_type=_UserType.normal.value, + user_type_nullable=_UserType.normal.value, + ) + session.add(normal_user) + session.flush() + normal_user_id = normal_user.id + session.commit() + + with Session(engine) as session: + user = session.query(_User).filter(_User.id == admin_user_id).first() + assert user.user_type == _UserType.admin + assert user.user_type_nullable is None + + with Session(engine) as session: + user = session.query(_User).filter(_User.id == normal_user_id).first() + assert user.user_type == _UserType.normal + assert user.user_type_nullable == _UserType.normal + + def test_insert_invalid_values(self): + def _session_insert_with_value(sess: Session, user_type: Any): + user = _User(name="test_user", user_type=user_type) + sess.add(user) + sess.flush() + + def _insert_with_user(sess: Session, user_type: Any): + stmt = insert(_User).values( + { + "name": "test_user", + "user_type": user_type, + } + ) + sess.execute(stmt) + + class TestCase(NamedTuple): + name: str + action: Callable[[Session], None] + exc_type: type[Exception] + + engine = sa.create_engine("sqlite://", echo=False) + _Base.metadata.create_all(engine) + cases = [ + TestCase( + name="session insert with invalid value", + action=lambda s: _session_insert_with_value(s, "invalid"), + exc_type=ValueError, + ), + TestCase( + name="session insert with invalid type", + action=lambda s: _session_insert_with_value(s, 1), + exc_type=TypeError, + ), + TestCase( + name="insert with invalid value", + action=lambda s: _insert_with_user(s, "invalid"), + exc_type=ValueError, + ), + TestCase( + name="insert with invalid type", + action=lambda s: _insert_with_user(s, 1), + exc_type=TypeError, + ), + ] + for idx, c in enumerate(cases, 1): + with pytest.raises(sa_exc.StatementError) as exc: + with Session(engine) as session: + c.action(session) + + assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}" + + def test_select_invalid_values(self): + engine = sa.create_engine("sqlite://", echo=False) + _Base.metadata.create_all(engine) + + insertion_sql = """ + INSERT INTO users (id, name, user_type) VALUES + (1, 'invalid_value', 'invalid'); + """ + with Session(engine) as session: + session.execute(sa.text(insertion_sql)) + session.commit() + + with pytest.raises(ValueError) as exc: + with Session(engine) as session: + _user = session.query(_User).filter(_User.id == 1).first()