mirror of https://github.com/langgenius/dify.git
fix mypy
This commit is contained in:
parent
ad37863183
commit
d4c2003450
|
|
@ -4,7 +4,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import func, select
|
||||
|
|
@ -270,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
|||
return self._get_specific_provider_credential(credential_id)
|
||||
|
||||
# Default behavior: return current active provider credentials
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
|
||||
|
||||
return self.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
|
|
@ -290,7 +290,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:return:
|
||||
"""
|
||||
|
||||
def _validate(s: Session) -> tuple[Provider | None, dict]:
|
||||
def _validate(s: Session) -> dict:
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
|
|
@ -452,7 +452,11 @@ class ProviderConfiguration(BaseModel):
|
|||
raise
|
||||
|
||||
def _update_load_balancing_configs_with_credential(
|
||||
self, credential_id: str, credential_record: dict, credential_source: str, session: Session
|
||||
self,
|
||||
credential_id: str,
|
||||
credential_record: ProviderCredential | ProviderModelCredential,
|
||||
credential_source: str,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Update load balancing configurations that reference the given credential_id.
|
||||
|
|
@ -513,13 +517,13 @@ class ProviderConfiguration(BaseModel):
|
|||
raise ValueError("Credential record not found.")
|
||||
|
||||
# Check if this credential is used in load balancing configs
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||
)
|
||||
lb_configs_using_credential = session.execute(stmt).scalars().all()
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
try:
|
||||
for lb_config in lb_configs_using_credential:
|
||||
lb_credentials_cache = ProviderCredentialsCache(
|
||||
|
|
@ -541,11 +545,11 @@ class ProviderConfiguration(BaseModel):
|
|||
|
||||
# Check available credentials count BEFORE deleting
|
||||
# if this is the last credential, we need to delete the provider record
|
||||
stmt = select(func.count(ProviderCredential.id)).where(
|
||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name == self.provider.provider,
|
||||
)
|
||||
available_credentials_count = session.execute(stmt).scalar()
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
|
||||
if provider_record and available_credentials_count <= 1:
|
||||
|
|
@ -644,14 +648,12 @@ class ProviderConfiguration(BaseModel):
|
|||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == model_record.provider_name,
|
||||
ProviderModelCredential.model_name == model_record.model_name,
|
||||
ProviderModelCredential.model_type == model_record.model_type,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
|
||||
credential = session.execute(stmt).scalar_one_or_none()
|
||||
|
|
@ -722,6 +724,7 @@ class ProviderConfiguration(BaseModel):
|
|||
if self.provider.model_credential_schema
|
||||
else [],
|
||||
)
|
||||
return None
|
||||
|
||||
def validate_custom_model_credentials(
|
||||
self,
|
||||
|
|
@ -730,7 +733,7 @@ class ProviderConfiguration(BaseModel):
|
|||
credentials: dict,
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
) -> tuple[ProviderModel | None, dict]:
|
||||
) -> dict:
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
|
||||
|
|
@ -741,7 +744,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:return:
|
||||
"""
|
||||
|
||||
def _validate(s: Session) -> tuple[ProviderModel | None, dict]:
|
||||
def _validate(s: Session) -> dict:
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
|
|
@ -938,13 +941,13 @@ class ProviderConfiguration(BaseModel):
|
|||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
||||
LoadBalancingModelConfig.credential_id == credential_id,
|
||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||
)
|
||||
lb_configs_using_credential = session.execute(stmt).scalars().all()
|
||||
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
|
||||
|
||||
try:
|
||||
for lb_config in lb_configs_using_credential:
|
||||
|
|
@ -966,13 +969,13 @@ class ProviderConfiguration(BaseModel):
|
|||
|
||||
# Check available credentials count BEFORE deleting
|
||||
# if this is the last credential, we need to delete the custom model record
|
||||
stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name == self.provider.provider,
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
)
|
||||
available_credentials_count = session.execute(stmt).scalar()
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
|
||||
if provider_model_record and available_credentials_count <= 1:
|
||||
|
|
@ -1161,7 +1164,7 @@ class ProviderConfiguration(BaseModel):
|
|||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
load_balancing_config_count = session.execute(stmt).scalar()
|
||||
load_balancing_config_count = session.execute(stmt).scalar() or 0
|
||||
if load_balancing_config_count <= 1:
|
||||
raise ValueError("Model load balancing configuration must be more than 1.")
|
||||
|
||||
|
|
@ -1264,8 +1267,8 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_name=self.provider.provider,
|
||||
preferred_provider_type=provider_type.value,
|
||||
)
|
||||
session.add(preferred_model_provider)
|
||||
session.commit()
|
||||
s.add(preferred_model_provider)
|
||||
s.commit()
|
||||
|
||||
if session:
|
||||
return _switch(session)
|
||||
|
|
|
|||
|
|
@ -333,6 +333,7 @@ class ModelLoadBalancingService:
|
|||
enabled = config.get("enabled")
|
||||
|
||||
if credential_id:
|
||||
credential_record: ProviderCredential | ProviderModelCredential | None = None
|
||||
if config_from == "predefined-model":
|
||||
credential_record = (
|
||||
db.session.query(ProviderCredential)
|
||||
|
|
@ -407,6 +408,7 @@ class ModelLoadBalancingService:
|
|||
|
||||
if credential_id:
|
||||
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
|
||||
assert credential_record is not None
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class ModelProviderService:
|
|||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
||||
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
"""
|
||||
|
|
@ -226,7 +226,7 @@ class ModelProviderService:
|
|||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_custom_model_credential(
|
||||
return provider_configuration.get_custom_model_credential( # type: ignore
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue