refactor: replace sa.String with EnumText in mapped_column for type s… (#33332)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
tmimmanuel 2026-03-14 04:38:27 +00:00 committed by GitHub
parent 6043ec4423
commit e64f4d6039
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 218 additions and 138 deletions

View File

@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -231,7 +232,7 @@ class AccountInitApi(Resource):
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = "active"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -12,6 +12,7 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
_logger = logging.getLogger(__name__)
@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler:
source="app",
source_app_id=self._app_id,
created_by_role=(
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
CreatorUserRole.ACCOUNT
if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatorUserRole.END_USER
),
created_by=self._user_id,
)

View File

@ -628,10 +628,10 @@ class TraceTask:
if not message_data:
return {}
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
conversation_modes = db.session.scalars(conversation_mode_stmt).all()
if not conversation_modes or len(conversation_modes) == 0:
return {}
conversation_mode = conversation_mode[0]
conversation_mode = conversation_modes[0]
created_at = message_data.created_at
inputs = message_data.message

View File

@ -627,7 +627,7 @@ class ProviderManager:
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
provider_type=ProviderType.SYSTEM,
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,

View File

@ -83,6 +83,7 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@ -1009,7 +1010,7 @@ class DatasetRetrieval:
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by_role=CreatorUserRole(user_from),
created_by=user_id,
)
dataset_queries.append(dataset_query)

View File

@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# No sequence number generation needed anymore
db_model.type = domain_model.workflow_type
from models.workflow import WorkflowType as ModelWorkflowType
db_model.type = ModelWorkflowType(domain_model.workflow_type.value)
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None

View File

@ -17,7 +17,8 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.triggered_from = data.get("triggered_from") or ""
triggered_from_val = data.get("triggered_from")
try:
model.triggered_from = (
WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val))
if triggered_from_val
else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
)
except ValueError:
logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val)
model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
model.node_id = data.get("node_id") or ""
model.node_type = data.get("node_type") or ""
model.status = data.get("status") or "running" # Default status if missing
model.title = data.get("title") or ""
model.created_by_role = data.get("created_by_role") or ""
created_by_role_val = data.get("created_by_role")
try:
model.created_by_role = (
CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT
)
except ValueError:
logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val)
model.created_by_role = CreatorUserRole.ACCOUNT
model.created_by = data.get("created_by") or ""
model.index = safe_int(data.get("index", 0))

View File

@ -22,12 +22,13 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from dify_graph.enums import WorkflowExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun, WorkflowType
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (
AverageInteractionStats,
@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.type = data.get("type") or ""
model.triggered_from = data.get("triggered_from") or ""
type_val = data.get("type")
try:
model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW
except ValueError:
logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val)
model.type = WorkflowType.WORKFLOW
triggered_from_val = data.get("triggered_from")
try:
model.triggered_from = (
WorkflowRunTriggeredFrom(str(triggered_from_val))
if triggered_from_val
else WorkflowRunTriggeredFrom.APP_RUN
)
except ValueError:
logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val)
model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN
model.version = data.get("version") or ""
model.status = data.get("status") or "running" # Default status if missing
model.created_by_role = data.get("created_by_role") or ""
status_val = data.get("status")
try:
model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING
except ValueError:
logger.warning("Invalid status value: %s, falling back to RUNNING", status_val)
model.status = WorkflowExecutionStatus.RUNNING
created_by_role_val = data.get("created_by_role")
try:
model.created_by_role = (
CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT
)
except ValueError:
logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val)
model.created_by_role = CreatorUserRole.ACCOUNT
model.created_by = data.get("created_by") or ""
model.total_tokens = safe_int(data.get("total_tokens", 0))

View File

@ -8,12 +8,12 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import deprecated
from .base import TypeBase
from .engine import db
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
class TenantAccountRole(enum.StrEnum):
@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase):
last_active_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active")
status: Mapped[AccountStatus] = mapped_column(
EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE
)
initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@validates("status")
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
if isinstance(value, AccountStatus):
return value.value
return value
@property
def is_password_set(self):
return self.password is not None
@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase):
return self.role
def get_status(self) -> AccountStatus:
status_str = self.status
return AccountStatus(status_str)
return self.status
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
@ -249,7 +244,9 @@ class Tenant(TypeBase):
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic")
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal")
status: Mapped[TenantStatus] = mapped_column(
EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL
)
custom_config: Mapped[str | None] = mapped_column(LongText, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False
@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False)
role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal")
role: Mapped[TenantAccountRole] = mapped_column(
EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL
)
invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.current_timestamp(), nullable=False, init=False

View File

@ -30,8 +30,9 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode,
from .account import Account
from .base import Base, TypeBase
from .engine import db
from .enums import CreatorUserRole
from .model import App, Tag, TagBinding, UploadFile
from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index
from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index
logger = logging.getLogger(__name__)
@ -59,7 +60,11 @@ class Dataset(Base):
name: Mapped[str] = mapped_column(String(255))
description = mapped_column(LongText, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'"))
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'"))
permission: Mapped[DatasetPermissionEnum] = mapped_column(
EnumText(DatasetPermissionEnum, length=255),
server_default=sa.text("'only_me'"),
default=DatasetPermissionEnum.ONLY_ME,
)
data_source_type = mapped_column(String(255))
indexing_technique: Mapped[str | None] = mapped_column(String(255))
index_struct = mapped_column(LongText, nullable=True)
@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase):
content: Mapped[str] = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False

View File

@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7
from .account import Account, Tenant
from .base import Base, TypeBase, gen_uuidv4_string
from .engine import db
from .enums import CreatorUserRole
from .enums import CreatorUserRole, MessageStatus
from .provider_ids import GenericProviderID
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
if TYPE_CHECKING:
from .workflow import Workflow
@ -337,8 +337,8 @@ class App(Base):
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(LongText, default=sa.text("''"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255))
icon = mapped_column(String(255))
icon_background: Mapped[str | None] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
@ -1000,7 +1000,7 @@ class Conversation(Base):
model_provider = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(LongText)
model_id = mapped_column(String(255), nullable=True)
mode: Mapped[str] = mapped_column(String(255))
mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(LongText)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
@ -1351,7 +1351,12 @@ class Message(Base):
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'"))
status: Mapped[MessageStatus] = mapped_column(
EnumText(MessageStatus, length=255),
nullable=False,
server_default=sa.text("'normal'"),
default=MessageStatus.NORMAL,
)
error: Mapped[str | None] = mapped_column(LongText)
message_metadata: Mapped[str | None] = mapped_column(LongText)
invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True)
@ -1364,7 +1369,7 @@ class Message(Base):
)
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True)
app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True)
@property
def inputs(self) -> dict[str, Any]:
@ -1767,7 +1772,7 @@ class MessageFile(TypeBase):
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
@ -2015,7 +2020,7 @@ class Site(Base):
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
icon_type = mapped_column(String(255), nullable=True)
icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True)
icon = mapped_column(String(255))
icon_background = mapped_column(String(255))
description = mapped_column(LongText)
@ -2110,7 +2115,12 @@ class UploadFile(Base):
# The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
# Its value is derived from the `CreatorUserRole` enumeration.
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'"))
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'account'"),
default=CreatorUserRole.ACCOUNT,
)
# The `created_by` field stores the ID of the entity that created this upload file.
#
@ -2163,7 +2173,7 @@ class UploadFile(Base):
self.size = size
self.extension = extension
self.mime_type = mime_type
self.created_by_role = created_by_role.value
self.created_by_role = created_by_role
self.created_by = created_by
self.created_at = created_at
self.used = used
@ -2226,7 +2236,7 @@ class MessageAgentThought(TypeBase):
)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)

View File

@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .types import LongText, StringUUID
from .types import EnumText, LongText, StringUUID
class ProviderType(StrEnum):
@ -69,8 +69,8 @@ class Provider(TypeBase):
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'custom'"), default="custom"
provider_type: Mapped[ProviderType] = mapped_column(
EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM
)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)

View File

@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase):
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)

View File

@ -2,13 +2,14 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
from .engine import db
from .enums import CreatorUserRole
from .model import Message
from .types import StringUUID
from .types import EnumText, StringUUID
class SavedMessage(TypeBase):
@ -24,7 +25,9 @@ class SavedMessage(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'"))
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'")
)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime,
@ -50,8 +53,8 @@ class PinnedConversation(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role: Mapped[str] = mapped_column(
String(255),
created_by_role: Mapped[CreatorUserRole] = mapped_column(
EnumText(CreatorUserRole, length=255),
nullable=False,
server_default=sa.text("'end_user'"),
)

View File

@ -53,7 +53,7 @@ from libs import helper
from .account import Account
from .base import Base, DefaultFieldsMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
@ -141,7 +141,7 @@ class Workflow(Base): # bug
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False)
version: Mapped[str] = mapped_column(String(255), nullable=False)
marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="")
marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="")
@ -188,7 +188,7 @@ class Workflow(Base): # bug
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
workflow.app_id = app_id
workflow.type = type
workflow.type = WorkflowType(type)
workflow.version = version
workflow.graph = graph
workflow.features = features
@ -608,8 +608,8 @@ class WorkflowRun(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
type: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255))
triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255))
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[str | None] = mapped_column(LongText)
inputs: Mapped[str | None] = mapped_column(LongText)
@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column(
EnumText(WorkflowNodeExecutionTriggeredFrom, length=255)
)
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID)
index: Mapped[int] = mapped_column(sa.Integer)
predecessor_node_id: Mapped[str | None] = mapped_column(String(255))
@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[str | None] = mapped_column(LongText)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255))
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255))
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase):
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase):
run_version: Mapped[str] = mapped_column(String(255), nullable=False)
run_status: Mapped[str] = mapped_column(String(255), nullable=False)
run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(
EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False
)
run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))

View File

@ -1089,9 +1089,9 @@ class TenantService:
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
ta.role = role
ta.role = TenantAccountRole(role)
else:
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role))
db.session.add(ta)
db.session.commit()
@ -1319,10 +1319,10 @@ class TenantService:
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if current_owner_join:
current_owner_join.role = "admin"
current_owner_join.role = TenantAccountRole.ADMIN
# Update the role of the target member
target_member_join.role = new_role
target_member_join.role = TenantAccountRole(new_role)
db.session.commit()
@staticmethod

View File

@ -429,17 +429,18 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
resolved_icon_type: IconType
if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]:
icon_type = icon_type_value
resolved_icon_type = IconType(icon_type_value)
else:
icon_type = IconType.EMOJI
resolved_icon_type = IconType.EMOJI
icon = icon or str(app_data.get("icon", ""))
if app:
# Update existing app
app.name = name or app_data.get("name", app.name)
app.description = description or app_data.get("description", app.description)
app.icon_type = icon_type
app.icon_type = resolved_icon_type
app.icon = icon
app.icon_background = icon_background or app_data.get("icon_background", app.icon_background)
app.updated_by = account.id
@ -452,10 +453,10 @@ class AppDslService:
app = App()
app.id = str(uuid4())
app.tenant_id = account.current_tenant_id
app.mode = app_mode.value
app.mode = app_mode
app.name = name or app_data.get("name", "")
app.description = description or app_data.get("description", "")
app.icon_type = icon_type
app.icon_type = resolved_icon_type
app.icon = icon
app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF")
app.enable_site = True
@ -549,7 +550,7 @@ class AppDslService:
"kind": "app",
"app": {
"name": app_model.name,
"mode": app_model.mode,
"mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode,
"icon": app_model.icon if app_model.icon_type == "image" else "🤖",
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
"description": app_model.description,

View File

@ -19,7 +19,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.model import App, AppMode, AppModelConfig, IconType, Site
from models.tools import ApiToolProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
@ -254,7 +254,7 @@ class AppService:
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)

View File

@ -254,7 +254,7 @@ class DatasetService:
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model_name if embedding_model else None
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
if summary_index_setting is not None:
dataset.summary_index_setting = summary_index_setting

View File

@ -13,6 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode
from extensions.ext_database import db
from models import Account
from models.dataset import Dataset, DatasetQuery
from models.enums import CreatorUserRole
logger = logging.getLogger(__name__)
@ -98,7 +99,7 @@ class HitTestingService:
content=json.dumps(dataset_queries),
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
db.session.add(dataset_query)
@ -138,7 +139,7 @@ class HitTestingService:
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)

View File

@ -3,6 +3,7 @@ from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser
from models.web import SavedMessage
from services.message_service import MessageService
@ -54,7 +55,7 @@ class SavedMessageService:
saved_message = SavedMessage(
app_id=app_model.id,
message_id=message.id,
created_by_role="account" if isinstance(user, Account) else "end_user",
created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER,
created_by=user.id,
)

View File

@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import CreatorUserRole
from models.model import App, EndUser
from models.web import PinnedConversation
from services.conversation_service import ConversationService
@ -84,7 +85,7 @@ class WebConversationService:
pinned_conversation = PinnedConversation(
app_id=app_model.id,
conversation_id=conversation.id,
created_by_role="account" if isinstance(user, Account) else "end_user",
created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER,
created_by=user.id,
)

View File

@ -24,7 +24,7 @@ from events.app_event import app_was_created
from extensions.ext_database import db
from models import Account
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import App, AppMode, AppModelConfig
from models.model import App, AppMode, AppModelConfig, IconType
from models.workflow import Workflow, WorkflowType
@ -72,7 +72,7 @@ class WorkflowConverter:
new_app.tenant_id = app_model.tenant_id
new_app.name = name or app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
new_app.icon_type = icon_type or app_model.icon_type
new_app.icon_type = IconType(icon_type) if icon_type else app_model.icon_type
new_app.icon = icon or app_model.icon
new_app.icon_background = icon_background or app_model.icon_background
new_app.enable_site = app_model.enable_site

View File

@ -164,7 +164,7 @@ def _record_trigger_failure_log(
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=created_by_role.value,
created_by_role=created_by_role,
created_by=created_by,
created_at=now,
finished_at=now,
@ -179,7 +179,7 @@ def _record_trigger_failure_log(
workflow_id=workflow.id,
workflow_run_id=workflow_run.id,
created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value,
created_by_role=created_by_role.value,
created_by_role=created_by_role,
created_by=created_by,
)
session.add(workflow_app_log)
@ -212,7 +212,7 @@ def _record_trigger_failure_log(
error=error_message,
queue_name=queue_name,
retry_count=0,
created_by_role=created_by_role.value,
created_by_role=created_by_role,
created_by=created_by,
triggered_at=now,
finished_at=now,

View File

@ -94,13 +94,15 @@ def _create_workflow_run_from_execution(
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
from models.workflow import WorkflowType as ModelWorkflowType
workflow_run.type = ModelWorkflowType(execution.workflow_type.value)
workflow_run.triggered_from = triggered_from
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.status = execution.status
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
@ -108,7 +110,7 @@ def _create_workflow_run_from_execution(
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by_role = creator_user_role
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.status = execution.status
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)

View File

@ -98,7 +98,7 @@ def _create_node_execution_from_domain(
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.triggered_from = triggered_from
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
@ -128,7 +128,7 @@ def _create_node_execution_from_domain(
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by_role = creator_user_role
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at

View File

@ -165,7 +165,7 @@ class TestChatMessageApiPermissions:
agent_thoughts=[],
message_files=[],
message_metadata_dict={},
status="success",
status="normal",
error="",
parent_message_id=None,
)

View File

@ -3331,7 +3331,7 @@ class TestRegisterService:
TenantService.create_tenant_member(tenant, account, role="normal")
# Change tenant status to non-normal
tenant.status = "suspended"
tenant.status = "archive"
db_session_with_containers.commit()

View File

@ -2,6 +2,7 @@ import uuid
from unittest.mock import ANY, MagicMock, patch
import pytest
import sqlalchemy as sa
from faker import Faker
from sqlalchemy.orm import Session
@ -492,20 +493,20 @@ class TestAppGenerateService:
)
# Manually set invalid mode after creation
# With EnumText, invalid values are rejected at the DB level during flush,
# raising StatementError wrapping ValueError
app.mode = "invalid_mode"
# Setup test arguments
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
# Execute the method under test and expect ValueError
with pytest.raises(ValueError) as exc_info:
# Execute the method under test and expect either ValueError (direct) or
# StatementError (from EnumText validation during autoflush)
with pytest.raises((ValueError, sa.exc.StatementError)):
AppGenerateService.generate(
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
)
# Verify error message
assert "Invalid app mode" in str(exc_info.value)
def test_generate_with_workflow_id_format_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):

View File

@ -163,7 +163,7 @@ class TestSavedMessageService:
answer_unit_price=0.002,
total_price=0.003,
currency="USD",
status="success",
status="normal",
)
db_session_with_containers.add(message)

View File

@ -62,7 +62,7 @@ class TestWorkflowService:
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
status="normal",
)
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year()
@ -1090,20 +1090,19 @@ class TestWorkflowService:
This test ensures that the service correctly handles feature validation
for unsupported app modes, preventing invalid operations.
With EnumText, invalid values are rejected at the DB level during flush,
raising StatementError wrapping ValueError.
"""
# Arrange
fake = Faker()
app = self._create_test_app(db_session_with_containers, fake)
app.mode = "invalid_mode" # Invalid mode
db_session_with_containers.commit()
# Act & Assert - EnumText validation rejects invalid values at DB flush
import sqlalchemy as sa
workflow_service = WorkflowService()
features = {"test": "value"}
# Act & Assert
with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"):
workflow_service.validate_features_structure(app_model=app, features=features)
with pytest.raises((ValueError, sa.exc.StatementError)):
db_session_with_containers.commit()
def test_update_workflow_success(self, db_session_with_containers: Session):
"""

View File

@ -110,7 +110,7 @@ class TestCleanDatasetTask:
tenant = Tenant(
name=fake.company(),
plan="basic",
status="active",
status="normal",
)
db_session_with_containers.add(tenant)

View File

@ -48,7 +48,7 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance
"""
fake = fake or Faker()
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active")
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal")
tenant.id = fake.uuid4()
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at

View File

@ -65,7 +65,7 @@ class TestDisableSegmentsFromIndexTask:
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="active",
status="normal",
)
tenant.id = account.tenant_id
tenant.created_at = fake.date_time_this_year()

View File

@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask:
tenant = Tenant(
name=fake.company(),
plan="basic",
status="active",
status="normal",
)
db_session_with_containers.add(tenant)

View File

@ -48,7 +48,7 @@ def make_message():
msg.query = "hello"
msg.re_sign_file_url_answer = ""
msg.user_feedback = MagicMock(rating=None)
msg.status = "success"
msg.status = "normal"
msg.error = None
return msg

View File

@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None:
{"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"},
message_file_obj,
],
status="success",
status="normal",
error=None,
message_metadata_dict={"meta": "value"},
extra_contents=[

View File

@ -3730,7 +3730,7 @@ class TestDatasetRetrievalAdditionalHelpers:
attachment_ids=None,
dataset_ids=["d1"],
app_id="a1",
user_from="web",
user_from="account",
user_id="u1",
)
mock_session.add_all.assert_not_called()
@ -3740,7 +3740,7 @@ class TestDatasetRetrievalAdditionalHelpers:
attachment_ids=["f1"],
dataset_ids=["d1", "d2"],
app_id="a1",
user_from="web",
user_from="account",
user_id="u1",
)
mock_session.add_all.assert_called()

View File

@ -5,6 +5,7 @@ from typing import Any
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.tool_parameter_cache import ToolParameterCache
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
@ -112,37 +113,38 @@ def test_encrypt_tool_parameters():
def test_decrypt_tool_parameters_cache_hit_and_miss():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = {"secret": "cached"}
with (
patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}),
patch.object(ToolParameterCache, "set") as mock_set,
):
assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"}
cache.set.assert_not_called()
mock_set.assert_not_called()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = None
with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"):
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
assert decrypted["secret"] == "dec"
cache.set.assert_called_once()
with (
patch.object(ToolParameterCache, "get", return_value=None),
patch.object(ToolParameterCache, "set") as mock_set,
patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"),
):
decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"})
assert decrypted["secret"] == "dec"
mock_set.assert_called_once()
def test_delete_tool_parameters_cache():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
with patch.object(ToolParameterCache, "delete") as mock_delete:
manager.delete_tool_parameters_cache()
cache_cls.return_value.delete.assert_called_once()
mock_delete.assert_called_once()
def test_configuration_manager_decrypt_suppresses_errors():
manager = _build_manager()
with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls:
cache = cache_cls.return_value
cache.get.return_value = None
with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")):
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
with (
patch.object(ToolParameterCache, "get", return_value=None),
patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")),
):
decrypted = manager.decrypt_tool_parameters({"secret": "enc"})
# decryption failure is suppressed, original value is retained.
assert decrypted["secret"] == "enc"

View File

@ -98,7 +98,7 @@ class TestAccountModelValidation:
)
# Assert
assert account.status == "active"
assert account.status == AccountStatus.ACTIVE
def test_account_get_status_method(self):
"""Test the get_status method returns AccountStatus enum."""
@ -106,7 +106,7 @@ class TestAccountModelValidation:
account = Account(
name="Test User",
email="test@example.com",
status="pending",
status=AccountStatus.PENDING,
)
# Act