diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py index da306fbc9d..5dfef6bf6a 100644 --- a/api/controllers/console/explore/banner.py +++ b/api/controllers/console/explore/banner.py @@ -4,6 +4,7 @@ from flask_restx import Resource from controllers.console import api from controllers.console.explore.wraps import explore_banner_enabled from extensions.ext_database import db +from models.enums import BannerStatus from models.model import ExporleBanner @@ -16,7 +17,7 @@ class BannerApi(Resource): language = request.args.get("language", "en-US") # Build base query for enabled banners - base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") + base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) # Try to get banners in the requested language banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index c6a270e470..a9f2300ba2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1422,12 +1422,12 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = s.execute(stmt).scalars().first() if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + preferred_model_provider.preferred_provider_type = provider_type else: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value, + preferred_provider_type=provider_type, ) s.add(preferred_model_provider) s.commit() diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index ed34922346..3c3fbd6dd2 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -195,7 +195,7 @@ class ProviderManager: preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) if preferred_provider_type_record: - preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + preferred_provider_type = preferred_provider_type_record.preferred_provider_type elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: preferred_provider_type = ProviderType.SYSTEM elif custom_configuration.provider or custom_configuration.models: diff --git a/api/models/model.py b/api/models/model.py index fe70fcd401..ff69d9d3a2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,7 +29,15 @@ from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import AppMCPServerStatus, AppStatus, ConversationStatus, CreatorUserRole, MessageStatus +from .enums import ( + AppMCPServerStatus, + AppStatus, + BannerStatus, + ConversationStatus, + CreatorUserRole, + MessageChainType, + MessageStatus, +) from .provider_ids import GenericProviderID from .types import EnumText, LongText, StringUUID @@ -925,8 +933,11 @@ class ExporleBanner(TypeBase): content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) - status: Mapped[str] = mapped_column( - sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + status: Mapped[BannerStatus] = mapped_column( + EnumText(BannerStatus, length=255), + nullable=False, + server_default=sa.text("'enabled'::character varying"), + default=BannerStatus.ENABLED, ) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -2206,7 +2217,7 @@ class MessageChain(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False) input: Mapped[str | None] = mapped_column(LongText, nullable=True) output: Mapped[str | None] = mapped_column(LongText, nullable=True) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/provider.py b/api/models/provider.py index 4e114bb034..afeee20b1e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -210,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index ef1f31d36b..7b5157fa61 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.enums import DataSourceType +from models.enums import DataSourceType, MessageChainType from models.model import ( App, AppAnnotationHitHistory, @@ -236,7 +236,7 @@ class TestMessagesCleanServiceIntegration: # MessageChain chain = MessageChain( message_id=message.id, - type="system", + type=MessageChainType.SYSTEM, input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index 0606219356..4414f1eb5f 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -2,6 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import controllers.console.explore.banner as banner_module +from models.enums import BannerStatus def unwrap(func): @@ -20,7 +21,7 @@ class TestBannerApi: banner.content = {"text": "hello"} banner.link = "https://example.com" banner.sort = 1 - banner.status = "enabled" + banner.status = BannerStatus.ENABLED banner.created_at = datetime(2024, 1, 1) query = MagicMock() @@ -54,7 +55,7 @@ class TestBannerApi: banner.content = {"text": "fallback"} banner.link = None banner.sort = 1 - banner.status = "enabled" + banner.status = BannerStatus.ENABLED banner.created_at = None query = MagicMock() diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 75473fc89a..95d58757f1 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -410,7 +410,7 @@ def test_switch_preferred_provider_type_updates_existing_record_with_session() - configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) - assert existing_record.preferred_provider_type == ProviderType.SYSTEM.value + assert existing_record.preferred_provider_type == ProviderType.SYSTEM session.commit.assert_called_once()