diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 708df62642..f497cd4bbb 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode +from models.account import AccountStatus from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -231,7 +232,7 @@ class AccountInitApi(Resource): account.interface_language = args.interface_language account.timezone = args.timezone account.interface_theme = "light" - account.status = "active" + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index d0279349ca..b054409681 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -12,6 +12,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import CreatorUserRole _logger = logging.getLogger(__name__) @@ -38,7 +39,9 @@ class DatasetIndexToolCallbackHandler: source="app", source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + CreatorUserRole.ACCOUNT + if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER ), created_by=self._user_id, ) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 33782e7949..9ac753240b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -628,10 +628,10 @@ class TraceTask: if not message_data: return {} conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) - conversation_mode = db.session.scalars(conversation_mode_stmt).all() - if not conversation_mode or len(conversation_mode) == 0: + conversation_modes = db.session.scalars(conversation_mode_stmt).all() + if not conversation_modes or len(conversation_modes) == 0: return {} - conversation_mode = conversation_mode[0] + conversation_mode = conversation_modes[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index f82c3a846b..c29a463bb6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -627,7 +627,7 @@ class ProviderManager: tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, quota_type=quota.quota_type, quota_limit=0, # type: ignore quota_used=0, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8243170c62..fcd3cceb59 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -83,6 +83,7 @@ from models.dataset import ( ) from models.dataset import Document as DatasetDocument from models.dataset import Document as DocumentModel +from models.enums import CreatorUserRole from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService @@ -1009,7 +1010,7 @@ class DatasetRetrieval: content=json.dumps(contents), source="app", source_app_id=app_id, - created_by_role=user_from, + created_by_role=CreatorUserRole(user_from), created_by=user_id, ) dataset_queries.append(dataset_query) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 770df8b050..55e96515ac 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # No sequence number generation needed anymore - db_model.type = domain_model.workflow_type + from models.workflow import WorkflowType as ModelWorkflowType + + db_model.type = ModelWorkflowType(domain_model.workflow_type.value) db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 7ee4638e77..a94d75ec76 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -17,7 +17,8 @@ from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from models.workflow import WorkflowNodeExecutionModel +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository logger = logging.getLogger(__name__) @@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.triggered_from = data.get("triggered_from") or "" + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val) + model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" model.status = data.get("status") or "running" # Default status if missing model.title = data.get("title") or "" - model.created_by_role = data.get("created_by_role") or "" + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.index = safe_int(data.get("index", 0)) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 14382ed876..bdfc81bd1c 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,12 +22,13 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker +from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.types import ( AverageInteractionStats, @@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.type = data.get("type") or "" - model.triggered_from = data.get("triggered_from") or "" + type_val = data.get("type") + try: + model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW + except ValueError: + logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val) + model.type = WorkflowType.WORKFLOW + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowRunTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowRunTriggeredFrom.APP_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val) + model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN model.version = data.get("version") or "" - model.status = data.get("status") or "running" # Default status if missing - model.created_by_role = data.get("created_by_role") or "" + status_val = data.get("status") + try: + model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING + except ValueError: + logger.warning("Invalid status value: %s, falling back to RUNNING", status_val) + model.status = WorkflowExecutionStatus.RUNNING + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.total_tokens = safe_int(data.get("total_tokens", 0)) diff --git a/api/models/account.py b/api/models/account.py index f7a9c20026..c354e4c0a5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,12 +8,12 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, validates +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class TenantAccountRole(enum.StrEnum): @@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase): last_active_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False ) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active") + status: Mapped[AccountStatus] = mapped_column( + EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE + ) initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) - @validates("status") - def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: - if isinstance(value, AccountStatus): - return value.value - return value - @property def is_password_set(self): return self.password is not None @@ -177,8 +173,7 @@ class Account(UserMixin, TypeBase): return self.role def get_status(self) -> AccountStatus: - status_str = self.status - return AccountStatus(status_str) + return self.status @classmethod def get_by_openid(cls, provider: str, open_id: str): @@ -249,7 +244,9 @@ class Tenant(TypeBase): name: Mapped[str] = mapped_column(String(255)) encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic") - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal") + status: Mapped[TenantStatus] = mapped_column( + EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL + ) custom_config: Mapped[str | None] = mapped_column(LongText, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -291,7 +288,9 @@ class TenantAccountJoin(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) - role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + role: Mapped[TenantAccountRole] = mapped_column( + EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL + ) invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False diff --git a/api/models/dataset.py b/api/models/dataset.py index 4ef39fcde1..b3fa11a58c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -30,8 +30,9 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db +from .enums import CreatorUserRole from .model import App, Tag, TagBinding, UploadFile -from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index +from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index logger = logging.getLogger(__name__) @@ -59,7 +60,11 @@ class Dataset(Base): name: Mapped[str] = mapped_column(String(255)) description = mapped_column(LongText, nullable=True) provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'")) - permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'")) + permission: Mapped[DatasetPermissionEnum] = mapped_column( + EnumText(DatasetPermissionEnum, length=255), + server_default=sa.text("'only_me'"), + default=DatasetPermissionEnum.ONLY_ME, + ) data_source_type = mapped_column(String(255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) @@ -1003,7 +1008,7 @@ class DatasetQuery(TypeBase): content: Mapped[str] = mapped_column(LongText, nullable=False) source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False diff --git a/api/models/model.py b/api/models/model.py index ed0614c195..5fd80c5757 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,9 +29,9 @@ from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import CreatorUserRole +from .enums import CreatorUserRole, MessageStatus from .provider_ids import GenericProviderID -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from .workflow import Workflow @@ -337,8 +337,8 @@ class App(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) - mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255)) icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) @@ -1000,7 +1000,7 @@ class Conversation(Base): model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(LongText) model_id = mapped_column(String(255), nullable=True) - mode: Mapped[str] = mapped_column(String(255)) + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(LongText) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) @@ -1351,7 +1351,12 @@ class Message(Base): provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[MessageStatus] = mapped_column( + EnumText(MessageStatus, length=255), + nullable=False, + server_default=sa.text("'normal'"), + default=MessageStatus.NORMAL, + ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) @@ -1364,7 +1369,7 @@ class Message(Base): ) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) - app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True) + app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True) @property def inputs(self) -> dict[str, Any]: @@ -1767,7 +1772,7 @@ class MessageFile(TypeBase): message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) @@ -2015,7 +2020,7 @@ class Site(Base): id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - icon_type = mapped_column(String(255), nullable=True) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True) icon = mapped_column(String(255)) icon_background = mapped_column(String(255)) description = mapped_column(LongText) @@ -2110,7 +2115,12 @@ class UploadFile(Base): # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # The `created_by` field stores the ID of the entity that created this upload file. # @@ -2163,7 +2173,7 @@ class UploadFile(Base): self.size = size self.extension = extension self.mime_type = mime_type - self.created_by_role = created_by_role.value + self.created_by_role = created_by_role self.created_by = created_by self.created_at = created_at self.used = used @@ -2226,7 +2236,7 @@ class MessageAgentThought(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) diff --git a/api/models/provider.py b/api/models/provider.py index 6175a3ae88..18a0fe92c8 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class ProviderType(StrEnum): @@ -69,8 +69,8 @@ class Provider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - provider_type: Mapped[str] = mapped_column( - String(40), nullable=False, server_default=text("'custom'"), default="custom" + provider_type: Mapped[ProviderType] = mapped_column( + EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False) last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False) diff --git a/api/models/trigger.py b/api/models/trigger.py index 209345eb84..43d7fc5b24 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -227,7 +227,7 @@ class WorkflowTriggerLog(TypeBase): queue_name: Mapped[str] = mapped_column(String(100), nullable=False) celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(String(255), nullable=False) retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) diff --git a/api/models/web.py b/api/models/web.py index 5f6a7b40bf..a1cc11c375 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,13 +2,14 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func +from sqlalchemy import DateTime, func from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase from .engine import db +from .enums import CreatorUserRole from .model import Message -from .types import StringUUID +from .types import EnumText, StringUUID class SavedMessage(TypeBase): @@ -24,7 +25,9 @@ class SavedMessage(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'") + ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, @@ -50,8 +53,8 @@ class PinnedConversation(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role: Mapped[str] = mapped_column( - String(255), + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'"), ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 21b899eeda..8c62292079 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -53,7 +53,7 @@ from libs import helper from .account import Account from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db -from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType +from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) @@ -141,7 +141,7 @@ class Workflow(Base): # bug id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False) version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="") marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="") @@ -188,7 +188,7 @@ class Workflow(Base): # bug workflow.id = str(uuid4()) workflow.tenant_id = tenant_id workflow.app_id = app_id - workflow.type = type + workflow.type = WorkflowType(type) workflow.version = version workflow.graph = graph workflow.features = features @@ -608,8 +608,8 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(String(255)) - triggered_from: Mapped[str] = mapped_column(String(255)) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255)) + triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255)) version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) @@ -830,7 +830,9 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column( + EnumText(WorkflowNodeExecutionTriggeredFrom, length=255) + ) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) @@ -846,7 +848,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(String(255)) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -1130,7 +1132,7 @@ class WorkflowAppLog(TypeBase): workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from: Mapped[str] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1204,7 +1206,7 @@ class WorkflowArchiveLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) @@ -1213,7 +1215,9 @@ class WorkflowArchiveLog(TypeBase): run_version: Mapped[str] = mapped_column(String(255), nullable=False) run_status: Mapped[str] = mapped_column(String(255), nullable=False) - run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( + EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False + ) run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) diff --git a/api/services/account_service.py b/api/services/account_service.py index f0eac2a522..bd520f54cf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1089,9 +1089,9 @@ class TenantService: ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: - ta.role = role + ta.role = TenantAccountRole(role) else: - ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role)) db.session.add(ta) db.session.commit() @@ -1319,10 +1319,10 @@ class TenantService: db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() ) if current_owner_join: - current_owner_join.role = "admin" + current_owner_join.role = TenantAccountRole.ADMIN # Update the role of the target member - target_member_join.role = new_role + target_member_join.role = TenantAccountRole(new_role) db.session.commit() @staticmethod diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 06f4ccb90e..49ca273442 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -429,17 +429,18 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") + resolved_icon_type: IconType if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: - icon_type = icon_type_value + resolved_icon_type = IconType(icon_type_value) else: - icon_type = IconType.EMOJI + resolved_icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: # Update existing app app.name = name or app_data.get("name", app.name) app.description = description or app_data.get("description", app.description) - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id @@ -452,10 +453,10 @@ class AppDslService: app = App() app.id = str(uuid4()) app.tenant_id = account.current_tenant_id - app.mode = app_mode.value + app.mode = app_mode app.name = name or app_data.get("name", "") app.description = description or app_data.get("description", "") - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") app.enable_site = True @@ -549,7 +550,7 @@ class AppDslService: "kind": "app", "app": { "name": app_model.name, - "mode": app_model.mode, + "mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode, "icon": app_model.icon if app_model.icon_type == "image" else "🤖", "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, "description": app_model.description, diff --git a/api/services/app_service.py b/api/services/app_service.py index aba8954f1a..b5e893c5b5 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -19,7 +19,7 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, IconType, Site from models.tools import ApiToolProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -254,7 +254,7 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = args["icon_type"] + app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3a7d483a9d..c527c71d7b 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -254,7 +254,7 @@ class DatasetService: dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model_name if embedding_model else None dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None - dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME dataset.provider = provider if summary_index_setting is not None: dataset.summary_index_setting = summary_index_setting diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index c00c76a826..d85b290534 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -13,6 +13,7 @@ from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery +from models.enums import CreatorUserRole logger = logging.getLogger(__name__) @@ -98,7 +99,7 @@ class HitTestingService: content=json.dumps(dataset_queries), source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) db.session.add(dataset_query) @@ -138,7 +139,7 @@ class HitTestingService: content=query, source="hit_testing", source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4dd6c8107b..d0f4f27968 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -3,6 +3,7 @@ from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService @@ -54,7 +55,7 @@ class SavedMessageService: saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 560aec2330..e028e3e5e3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService @@ -84,7 +85,7 @@ class WebConversationService: pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 0153046acc..3acbc93678 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -24,7 +24,7 @@ from events.app_event import app_was_created from extensions.ext_database import db from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, IconType from models.workflow import Workflow, WorkflowType @@ -72,7 +72,7 @@ class WorkflowConverter: new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW - new_app.icon_type = icon_type or app_model.icon_type + new_app.icon_type = IconType(icon_type) if icon_type else app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d06b8c980b..e7f4e37c75 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -164,7 +164,7 @@ def _record_trigger_failure_log( elapsed_time=0.0, total_tokens=0, total_steps=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, created_at=now, finished_at=now, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, ) session.add(workflow_app_log) @@ -212,7 +212,7 @@ def _record_trigger_failure_log( error=error_message, queue_name=queue_name, retry_count=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, triggered_at=now, finished_at=now, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index db8721e90b..f41118e592 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -94,13 +94,15 @@ def _create_workflow_run_from_execution( workflow_run.tenant_id = tenant_id workflow_run.app_id = app_id workflow_run.workflow_id = execution.workflow_id - workflow_run.type = execution.workflow_type.value - workflow_run.triggered_from = triggered_from.value + from models.workflow import WorkflowType as ModelWorkflowType + + workflow_run.type = ModelWorkflowType(execution.workflow_type.value) + workflow_run.triggered_from = triggered_from workflow_run.version = execution.workflow_version json_converter = WorkflowRuntimeTypeConverter() workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) @@ -108,7 +110,7 @@ def _create_workflow_run_from_execution( workflow_run.elapsed_time = execution.elapsed_time workflow_run.total_tokens = execution.total_tokens workflow_run.total_steps = execution.total_steps - workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by_role = creator_user_role workflow_run.created_by = creator_user_id workflow_run.created_at = execution.started_at workflow_run.finished_at = execution.finished_at @@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo Update a WorkflowRun database model from a WorkflowExecution domain entity. """ json_converter = WorkflowRuntimeTypeConverter() - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 3f607dc55e..eaafbf99e3 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -98,7 +98,7 @@ def _create_node_execution_from_domain( node_execution.tenant_id = tenant_id node_execution.app_id = app_id node_execution.workflow_id = execution.workflow_id - node_execution.triggered_from = triggered_from.value + node_execution.triggered_from = triggered_from node_execution.workflow_run_id = execution.workflow_execution_id node_execution.index = execution.index node_execution.predecessor_node_id = execution.predecessor_node_id @@ -128,7 +128,7 @@ def _create_node_execution_from_domain( node_execution.status = execution.status.value node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time - node_execution.created_by_role = creator_user_role.value + node_execution.created_by_role = creator_user_role node_execution.created_by = creator_user_id node_execution.created_at = execution.created_at node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index 498ac56d5d..afb6938baa 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -165,7 +165,7 @@ class TestChatMessageApiPermissions: agent_thoughts=[], message_files=[], message_metadata_dict={}, - status="success", + status="normal", error="", parent_message_id=None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 9354a3ac35..cc9596d15f 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "suspended" + tenant.status = "archive" db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5155d50b0e..5b1a4790f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -2,6 +2,7 @@ import uuid from unittest.mock import ANY, MagicMock, patch import pytest +import sqlalchemy as sa from faker import Faker from sqlalchemy.orm import Session @@ -492,20 +493,20 @@ class TestAppGenerateService: ) # Manually set invalid mode after creation + # With EnumText, invalid values are rejected at the DB level during flush, + # raising StatementError wrapping ValueError app.mode = "invalid_mode" # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} - # Execute the method under test and expect ValueError - with pytest.raises(ValueError) as exc_info: + # Execute the method under test and expect either ValueError (direct) or + # StatementError (from EnumText validation during autoflush) + with pytest.raises((ValueError, sa.exc.StatementError)): AppGenerateService.generate( app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True ) - # Verify error message - assert "Invalid app mode" in str(exc_info.value) - def test_generate_with_workflow_id_format_error( self, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index cc403ef5a2..dd743d46c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -163,7 +163,7 @@ class TestSavedMessageService: answer_unit_price=0.002, total_price=0.003, currency="USD", - status="success", + status="normal", ) db_session_with_containers.add(message) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index bfb23bac68..d8b43efeba 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -1090,20 +1090,19 @@ class TestWorkflowService: This test ensures that the service correctly handles feature validation for unsupported app modes, preventing invalid operations. + With EnumText, invalid values are rejected at the DB level during flush, + raising StatementError wrapping ValueError. """ # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - db_session_with_containers.commit() + # Act & Assert - EnumText validation rejects invalid values at DB flush + import sqlalchemy as sa - workflow_service = WorkflowService() - features = {"test": "value"} - - # Act & Assert - with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): - workflow_service.validate_features_structure(app_model=app, features=features) + with pytest.raises((ValueError, sa.exc.StatementError)): + db_session_with_containers.commit() def test_update_workflow_success(self, db_session_with_containers: Session): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 8eb881258a..41d9fc8a29 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -110,7 +110,7 @@ class TestCleanDatasetTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index bc0ed3bd2b..69ed5b632d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -48,7 +48,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 8f47b48ae2..6f7d2c28b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -65,7 +65,7 @@ class TestDisableSegmentsFromIndexTask: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 3cdec70df7..c0ddc27286 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index c3a6522e6d..6b5c304884 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -48,7 +48,7 @@ def make_message(): msg.query = "hello" msg.re_sign_file_url_answer = "" msg.user_feedback = MagicMock(rating=None) - msg.status = "success" + msg.status = "normal" msg.error = None return msg diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2bb425cdba 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None: {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, message_file_obj, ], - status="success", + status="normal", error=None, message_metadata_dict={"meta": "value"}, extra_contents=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b90c4935af..de3ccc4518 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -3730,7 +3730,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=None, dataset_ids=["d1"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_not_called() @@ -3740,7 +3740,7 @@ class TestDatasetRetrievalAdditionalHelpers: attachment_ids=["f1"], dataset_ids=["d1", "d2"], app_id="a1", - user_from="web", + user_from="account", user_id="u1", ) mock_session.add_all.assert_called() diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 4d59affb99..5ceaa08893 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -5,6 +5,7 @@ from typing import Any from unittest.mock import patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.tool_parameter_cache import ToolParameterCache from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -112,37 +113,38 @@ def test_encrypt_tool_parameters(): def test_decrypt_tool_parameters_cache_hit_and_miss(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = {"secret": "cached"} + with ( + patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}), + patch.object(ToolParameterCache, "set") as mock_set, + ): assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"} - cache.set.assert_not_called() + mock_set.assert_not_called() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = None - with patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"): - decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) - - assert decrypted["secret"] == "dec" - cache.set.assert_called_once() + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch.object(ToolParameterCache, "set") as mock_set, + patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) + assert decrypted["secret"] == "dec" + mock_set.assert_called_once() def test_delete_tool_parameters_cache(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: + with patch.object(ToolParameterCache, "delete") as mock_delete: manager.delete_tool_parameters_cache() - cache_cls.return_value.delete.assert_called_once() + mock_delete.assert_called_once() def test_configuration_manager_decrypt_suppresses_errors(): manager = _build_manager() - with patch("core.tools.utils.configuration.ToolParameterCache") as cache_cls: - cache = cache_cls.return_value - cache.get.return_value = None - with patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): - decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) # decryption failure is suppressed, original value is retained. assert decrypted["secret"] == "enc" diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index cc311d447f..1726fc2e8b 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -98,7 +98,7 @@ class TestAccountModelValidation: ) # Assert - assert account.status == "active" + assert account.status == AccountStatus.ACTIVE def test_account_get_status_method(self): """Test the get_status method returns AccountStatus enum.""" @@ -106,7 +106,7 @@ class TestAccountModelValidation: account = Account( name="Test User", email="test@example.com", - status="pending", + status=AccountStatus.PENDING, ) # Act