diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 2ec1a9435a..9182dbb510 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource): provider=provider, ) else: - # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) - normalized_model_type = args.model_type.to_origin_model_type() available_credentials = model_provider_service.get_provider_model_available_credentials( tenant_id=tenant_id, provider=provider, - model_type=normalized_model_type, + model_type=args.model_type, model=args.model, ) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 8b48aa2660..782897aea9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ), ) @@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel): ProviderModel.tenant_id == self.tenant_id, ProviderModel.provider_name.in_(provider_names), ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type(), + ProviderModel.model_type == model_type, ) return session.execute(stmt).scalar_one_or_none() @@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = session.execute(stmt).scalar_one_or_none() @@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ProviderModelCredential.credential_name == credential_name, ) if exclude_id: @@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = s.execute(stmt).scalar_one_or_none() original_credentials = ( @@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, - model_type=model_type.to_origin_model_type(), + model_type=model_type, encrypted_config=json.dumps(credentials), credential_name=credential_name, ) @@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, - model_type=model_type.to_origin_model_type(), + model_type=model_type, credential_id=credential.id, is_valid=True, ) @@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) @@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, provider_name=self.provider.provider, model_name=model, - model_type=model_type.to_origin_model_type(), + model_type=model_type, is_valid=True, credential_id=credential_id, ) @@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, - ProviderModelCredential.model_type == model_type.to_origin_model_type(), + ProviderModelCredential.model_type == model_type, ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelSetting).where( ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.provider_name.in_(self._get_provider_names()), - ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_type == model_type, ProviderModelSetting.model_name == model, ) return session.execute(stmt).scalars().first() @@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel): model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type, model_name=model, enabled=True, ) @@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel): model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type, model_name=model, enabled=False, ) @@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel): stmt = select(func.count(LoadBalancingModelConfig.id)).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(provider_names), - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type, LoadBalancingModelConfig.model_name == model, ) load_balancing_config_count = session.execute(stmt).scalar() or 0 @@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel): model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type, model_name=model, load_balancing_enabled=True, ) @@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel): model_setting = ProviderModelSetting( tenant_id=self.tenant_id, provider_name=self.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type, model_name=model, load_balancing_enabled=False, ) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index b2a8e9c114..5d536e0e32 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -306,7 +306,7 @@ class ProviderManager: """ stmt = select(TenantDefaultModel).where( TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), + TenantDefaultModel.model_type == model_type, ) default_model = db.session.scalar(stmt) @@ -324,7 +324,7 @@ class ProviderManager: default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), + model_type=model_type, provider_name=available_model.provider.provider, model_name=available_model.model, ) @@ -391,7 +391,7 @@ class ProviderManager: raise ValueError(f"Model {model} does not exist.") stmt = select(TenantDefaultModel).where( TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), + TenantDefaultModel.model_type == model_type, ) default_model = db.session.scalar(stmt) @@ -405,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), + model_type=model_type, provider_name=provider, model_name=model, ) @@ -822,7 +822,7 @@ class ProviderManager: custom_model_configurations.append( CustomModelConfiguration( model=provider_model_record.model_name, - model_type=ModelType.value_of(provider_model_record.model_type), + model_type=provider_model_record.model_type, credentials=provider_model_credentials, current_credential_id=provider_model_record.credential_id, current_credential_name=provider_model_record.credential_name, @@ -1201,7 +1201,7 @@ class ProviderManager: model_settings.append( ModelSettings( model=provider_model_setting.model_name, - model_type=ModelType.value_of(provider_model_setting.model_type), + model_type=provider_model_setting.model_type, enabled=provider_model_setting.enabled, load_balancing_enabled=provider_model_setting.load_balancing_enabled, load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], diff --git a/api/models/provider.py b/api/models/provider.py index bdcfb7aa0d..8270961b31 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,6 +6,7 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column @@ -131,7 +132,7 @@ class ProviderModel(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - model_type: Mapped[str] = mapped_column(String(40), nullable=False) + model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False) created_at: Mapped[datetime] = mapped_column( @@ -173,7 +174,7 @@ class TenantDefaultModel(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - model_type: Mapped[str] = mapped_column(String(40), nullable=False) + model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -253,7 +254,7 @@ class ProviderModelSetting(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - model_type: Mapped[str] = mapped_column(String(40), nullable=False) + model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) load_balancing_enabled: Mapped[bool] = mapped_column( sa.Boolean, nullable=False, server_default=text("false"), default=False @@ -283,7 +284,7 @@ class LoadBalancingModelConfig(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - model_type: Mapped[str] = mapped_column(String(40), nullable=False) + model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) @@ -348,7 +349,7 @@ class ProviderModelCredential(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - model_type: Mapped[str] = mapped_column(String(40), nullable=False) + model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False) credential_name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 25de411e43..752d3002d9 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -115,7 +115,7 @@ class ModelLoadBalancingService: .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum, LoadBalancingModelConfig.model_name == model, or_( LoadBalancingModelConfig.credential_source_type == credential_source_type, @@ -240,7 +240,7 @@ class ModelLoadBalancingService: .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum, LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -288,7 +288,7 @@ class ModelLoadBalancingService: inherit_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type, model_name=model, name="__inherit__", ) @@ -328,7 +328,7 @@ class ModelLoadBalancingService: select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum, LoadBalancingModelConfig.model_name == model, ) ).all() @@ -368,7 +368,7 @@ class ModelLoadBalancingService: tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, model_name=model, - model_type=model_type_enum.to_origin_model_type(), + model_type=model_type_enum, ) .first() ) @@ -432,7 +432,7 @@ class ModelLoadBalancingService: load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type_enum.to_origin_model_type(), + model_type=model_type_enum, model_name=model, name=credential_record.credential_name, encrypted_config=credential_record.encrypted_config, @@ -460,7 +460,7 @@ class ModelLoadBalancingService: load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type_enum.to_origin_model_type(), + model_type=model_type_enum, model_name=model, name=name, encrypted_config=json.dumps(credentials), @@ -515,7 +515,7 @@ class ModelLoadBalancingService: .where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum, LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index ca6e7afeab..aca3839135 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -141,7 +141,7 @@ class TestModelLoadBalancingService: tenant_id=tenant_id, provider_name="openai", model_name="gpt-3.5-turbo", - model_type="text-generation", # Use the origin model type that matches the query + model_type="llm", enabled=True, load_balancing_enabled=False, ) @@ -298,7 +298,7 @@ class TestModelLoadBalancingService: tenant_id=tenant.id, provider_name="openai", model_name="gpt-3.5-turbo", - model_type="text-generation", # Use the origin model type that matches the query + model_type="llm", name="config1", encrypted_config='{"api_key": "test_key"}', enabled=True, @@ -417,7 +417,7 @@ class TestModelLoadBalancingService: tenant_id=tenant.id, provider_name="openai", model_name="gpt-3.5-turbo", - model_type="text-generation", # Use the origin model type that matches the query + model_type="llm", name="config1", encrypted_config='{"api_key": "test_key"}', enabled=True, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 259cb5fdd0..ee26172459 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -48,7 +48,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", enabled=True, load_balancing_enabled=True, ) @@ -61,7 +61,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", name="__inherit__", encrypted_config=None, enabled=True, @@ -70,7 +70,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", name="first", encrypted_config='{"openai_api_key": "fake_key"}', enabled=True, @@ -110,7 +110,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", enabled=True, load_balancing_enabled=True, ) @@ -121,7 +121,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", name="__inherit__", encrypted_config=None, enabled=True, @@ -157,7 +157,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", enabled=True, load_balancing_enabled=False, ) @@ -168,7 +168,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", name="__inherit__", encrypted_config=None, enabled=True, @@ -177,7 +177,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent tenant_id="tenant_id", provider_name="openai", model_name="gpt-4", - model_type="text-generation", + model_type="llm", name="first", encrypted_config='{"openai_api_key": "fake_key"}', enabled=True, @@ -270,7 +270,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc tenant_id="tenant-id", provider_name="openai", model_name="gpt-4", - model_type=ModelType.LLM.to_origin_model_type(), + model_type=ModelType.LLM, ) mock_session = Mock() mock_session.scalar.return_value = existing_default_model @@ -449,7 +449,7 @@ def test_update_default_model_record_updates_existing_record(mocker: MockerFixtu tenant_id="tenant-id", provider_name="anthropic", model_name="claude-3-sonnet", - model_type=ModelType.LLM.to_origin_model_type(), + model_type=ModelType.LLM, ) mock_session = Mock() mock_session.scalar.return_value = existing_default_model @@ -487,7 +487,7 @@ def test_update_default_model_record_creates_record_with_origin_model_type(mocke assert created_default_model.tenant_id == "tenant-id" assert created_default_model.provider_name == "openai" assert created_default_model.model_name == "gpt-4" - assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + assert created_default_model.model_type == ModelType.LLM mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index b43e79dff5..f85f1ace16 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -317,7 +317,7 @@ def test_init_inherit_config_should_create_and_persist_inherit_configuration( assert inherit_config.tenant_id == "tenant-1" assert inherit_config.provider_name == "openai" assert inherit_config.model_name == "gpt-4o-mini" - assert inherit_config.model_type == "text-generation" + assert inherit_config.model_type == "llm" assert inherit_config.name == "__inherit__" mock_db.session.add.assert_called_once_with(inherit_config) mock_db.session.commit.assert_called_once()