refactor: use EnumText for model_type in provider models (#34300)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
tmimmanuel 2026-03-31 02:31:33 +02:00 committed by GitHub
parent 5897b28355
commit 1344c3b280
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 55 additions and 56 deletions

View File

@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource):
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=normalized_model_type,
model_type=args.model_type,
model=args.model,
)

View File

@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
),
)
@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type.to_origin_model_type(),
ProviderModel.model_type == model_type,
)
return session.execute(stmt).scalar_one_or_none()
@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.credential_name == credential_name,
)
if exclude_id:
@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
original_credentials = (
@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
encrypted_config=json.dumps(credentials),
credential_name=credential_name,
)
@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
credential_id=credential.id,
is_valid=True,
)
@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
is_valid=True,
credential_id=credential_id,
)
@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.model_type == model_type,
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_type == model_type,
ProviderModelSetting.model_name == model,
)
return session.execute(stmt).scalars().first()
@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
enabled=True,
)
@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
enabled=False,
)
@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type,
LoadBalancingModelConfig.model_name == model,
)
load_balancing_config_count = session.execute(stmt).scalar() or 0
@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
load_balancing_enabled=True,
)
@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
load_balancing_enabled=False,
)

View File

@ -306,7 +306,7 @@ class ProviderManager:
"""
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
TenantDefaultModel.model_type == model_type,
)
default_model = db.session.scalar(stmt)
@ -324,7 +324,7 @@ class ProviderManager:
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
provider_name=available_model.provider.provider,
model_name=available_model.model,
)
@ -391,7 +391,7 @@ class ProviderManager:
raise ValueError(f"Model {model} does not exist.")
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
TenantDefaultModel.model_type == model_type,
)
default_model = db.session.scalar(stmt)
@ -405,7 +405,7 @@ class ProviderManager:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
provider_name=provider,
model_name=model,
)
@ -822,7 +822,7 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
model_type=ModelType.value_of(provider_model_record.model_type),
model_type=provider_model_record.model_type,
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
@ -1201,7 +1201,7 @@ class ProviderManager:
model_settings.append(
ModelSettings(
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
model_type=provider_model_setting.model_type,
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],

View File

@ -6,6 +6,7 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
@ -131,7 +132,7 @@ class ProviderModel(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
created_at: Mapped[datetime] = mapped_column(
@ -173,7 +174,7 @@ class TenantDefaultModel(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
@ -253,7 +254,7 @@ class ProviderModelSetting(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
load_balancing_enabled: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=text("false"), default=False
@ -283,7 +284,7 @@ class LoadBalancingModelConfig(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@ -348,7 +349,7 @@ class ProviderModelCredential(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
created_at: Mapped[datetime] = mapped_column(

View File

@ -115,7 +115,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
@ -240,7 +240,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)
@ -288,7 +288,7 @@ class ModelLoadBalancingService:
inherit_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider,
model_type=model_type.to_origin_model_type(),
model_type=model_type,
model_name=model,
name="__inherit__",
)
@ -328,7 +328,7 @@ class ModelLoadBalancingService:
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
)
).all()
@ -368,7 +368,7 @@ class ModelLoadBalancingService:
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_name=model,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
)
.first()
)
@ -432,7 +432,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
model_name=model,
name=credential_record.credential_name,
encrypted_config=credential_record.encrypted_config,
@ -460,7 +460,7 @@ class ModelLoadBalancingService:
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,
model_type=model_type_enum.to_origin_model_type(),
model_type=model_type_enum,
model_name=model,
name=name,
encrypted_config=json.dumps(credentials),
@ -515,7 +515,7 @@ class ModelLoadBalancingService:
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_type == model_type_enum,
LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id,
)

View File

@ -141,7 +141,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant_id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
enabled=True,
load_balancing_enabled=False,
)
@ -298,7 +298,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,
@ -417,7 +417,7 @@ class TestModelLoadBalancingService:
tenant_id=tenant.id,
provider_name="openai",
model_name="gpt-3.5-turbo",
model_type="text-generation", # Use the origin model type that matches the query
model_type="llm",
name="config1",
encrypted_config='{"api_key": "test_key"}',
enabled=True,

View File

@ -48,7 +48,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
enabled=True,
load_balancing_enabled=True,
)
@ -61,7 +61,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
name="__inherit__",
encrypted_config=None,
enabled=True,
@ -70,7 +70,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
@ -110,7 +110,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
enabled=True,
load_balancing_enabled=True,
)
@ -121,7 +121,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
name="__inherit__",
encrypted_config=None,
enabled=True,
@ -157,7 +157,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
enabled=True,
load_balancing_enabled=False,
)
@ -168,7 +168,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
name="__inherit__",
encrypted_config=None,
enabled=True,
@ -177,7 +177,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
model_type="llm",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
@ -270,7 +270,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
tenant_id="tenant-id",
provider_name="openai",
model_name="gpt-4",
model_type=ModelType.LLM.to_origin_model_type(),
model_type=ModelType.LLM,
)
mock_session = Mock()
mock_session.scalar.return_value = existing_default_model
@ -449,7 +449,7 @@ def test_update_default_model_record_updates_existing_record(mocker: MockerFixtu
tenant_id="tenant-id",
provider_name="anthropic",
model_name="claude-3-sonnet",
model_type=ModelType.LLM.to_origin_model_type(),
model_type=ModelType.LLM,
)
mock_session = Mock()
mock_session.scalar.return_value = existing_default_model
@ -487,7 +487,7 @@ def test_update_default_model_record_creates_record_with_origin_model_type(mocke
assert created_default_model.tenant_id == "tenant-id"
assert created_default_model.provider_name == "openai"
assert created_default_model.model_name == "gpt-4"
assert created_default_model.model_type == ModelType.LLM.to_origin_model_type()
assert created_default_model.model_type == ModelType.LLM
mock_session.commit.assert_called_once()

View File

@ -317,7 +317,7 @@ def test_init_inherit_config_should_create_and_persist_inherit_configuration(
assert inherit_config.tenant_id == "tenant-1"
assert inherit_config.provider_name == "openai"
assert inherit_config.model_name == "gpt-4o-mini"
assert inherit_config.model_type == "text-generation"
assert inherit_config.model_type == "llm"
assert inherit_config.name == "__inherit__"
mock_db.session.add.assert_called_once_with(inherit_config)
mock_db.session.commit.assert_called_once()