From d4c2003450e75bdd83fcc6581e60c8bf167b7422 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Tue, 19 Aug 2025 10:57:04 +0800 Subject: [PATCH] fix mypy --- api/core/entities/provider_configuration.py | 47 +++++++++++--------- api/services/model_load_balancing_service.py | 2 + api/services/model_provider_service.py | 4 +- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 74fc69b955..f6be2d75b2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from typing import Optional +from typing import Optional, Any from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import func, select @@ -270,7 +270,7 @@ class ProviderConfiguration(BaseModel): return self._get_specific_provider_credential(credential_id) # Default behavior: return current active provider credentials - credentials = self.custom_configuration.provider.credentials + credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {} return self.obfuscated_credentials( credentials=credentials, @@ -290,7 +290,7 @@ class ProviderConfiguration(BaseModel): :return: """ - def _validate(s: Session) -> tuple[Provider | None, dict]: + def _validate(s: Session) -> dict: # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas @@ -452,7 +452,11 @@ class ProviderConfiguration(BaseModel): raise def _update_load_balancing_configs_with_credential( - self, credential_id: str, credential_record: dict, credential_source: str, session: Session + self, + credential_id: str, + credential_record: ProviderCredential | ProviderModelCredential, + credential_source: str, + session: Session, ) -> None: """ Update load balancing configurations that reference the given credential_id. @@ -513,13 +517,13 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") # Check if this credential is used in load balancing configs - stmt = select(LoadBalancingModelConfig).where( + lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "provider", ) - lb_configs_using_credential = session.execute(stmt).scalars().all() + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: for lb_config in lb_configs_using_credential: lb_credentials_cache = ProviderCredentialsCache( @@ -541,11 +545,11 @@ class ProviderConfiguration(BaseModel): # Check available credentials count BEFORE deleting # if this is the last credential, we need to delete the provider record - stmt = select(func.count(ProviderCredential.id)).where( + count_stmt = select(func.count(ProviderCredential.id)).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name == self.provider.provider, ) - available_credentials_count = session.execute(stmt).scalar() + available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) if provider_record and available_credentials_count <= 1: @@ -644,14 +648,12 @@ class ProviderConfiguration(BaseModel): ) with Session(db.engine) as session: - model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) - stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == model_record.provider_name, - ProviderModelCredential.model_name == model_record.model_name, - ProviderModelCredential.model_type == model_record.model_type, + ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.model_name == model, + ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) credential = session.execute(stmt).scalar_one_or_none() @@ -722,6 +724,7 @@ class ProviderConfiguration(BaseModel): if self.provider.model_credential_schema else [], ) + return None def validate_custom_model_credentials( self, @@ -730,7 +733,7 @@ class ProviderConfiguration(BaseModel): credentials: dict, credential_id: str = "", session: Session | None = None, - ) -> tuple[ProviderModel | None, dict]: + ) -> dict: """ Validate custom model credentials. @@ -741,7 +744,7 @@ class ProviderConfiguration(BaseModel): :return: """ - def _validate(s: Session) -> tuple[ProviderModel | None, dict]: + def _validate(s: Session) -> dict: # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas @@ -938,13 +941,13 @@ class ProviderConfiguration(BaseModel): if not credential_record: raise ValueError("Credential record not found.") - stmt = select(LoadBalancingModelConfig).where( + lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "custom_model", ) - lb_configs_using_credential = session.execute(stmt).scalars().all() + lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: for lb_config in lb_configs_using_credential: @@ -966,13 +969,13 @@ class ProviderConfiguration(BaseModel): # Check available credentials count BEFORE deleting # if this is the last credential, we need to delete the custom model record - stmt = select(func.count(ProviderModelCredential.id)).where( + count_stmt = select(func.count(ProviderModelCredential.id)).where( ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) - available_credentials_count = session.execute(stmt).scalar() + available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) if provider_model_record and available_credentials_count <= 1: @@ -1161,7 +1164,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - load_balancing_config_count = session.execute(stmt).scalar() + load_balancing_config_count = session.execute(stmt).scalar() or 0 if load_balancing_config_count <= 1: raise ValueError("Model load balancing configuration must be more than 1.") @@ -1264,8 +1267,8 @@ class ProviderConfiguration(BaseModel): provider_name=self.provider.provider, preferred_provider_type=provider_type.value, ) - session.add(preferred_model_provider) - session.commit() + s.add(preferred_model_provider) + s.commit() if session: return _switch(session) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 4b677203ba..b9a7d580fa 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -333,6 +333,7 @@ class ModelLoadBalancingService: enabled = config.get("enabled") if credential_id: + credential_record: ProviderCredential | ProviderModelCredential | None = None if config_from == "predefined-model": credential_record = ( db.session.query(ProviderCredential) @@ -407,6 +408,7 @@ class ModelLoadBalancingService: if credential_id: credential_source = "provider" if config_from == "predefined-model" else "custom_model" + assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index e330d2852f..62cf9884e4 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -138,7 +138,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_provider_credential(credential_id=credential_id) + return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: """ @@ -226,7 +226,7 @@ class ModelProviderService: :return: """ provider_configuration = self._get_provider_configuration(tenant_id, provider) - return provider_configuration.get_custom_model_credential( + return provider_configuration.get_custom_model_credential( # type: ignore model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id )