This commit is contained in:
wangxiaolei 2026-05-09 08:51:29 +08:00 committed by GitHub
commit 35ca3190ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 193 additions and 6 deletions

View File

@ -302,6 +302,196 @@ class ProviderManager:
# Return the encapsulated object
return provider_configurations
def get_single_provider_configuration(self, tenant_id: str, provider_name: str) -> ProviderConfiguration | None:
"""
Get single provider configuration efficiently.
Only queries data for the specified provider instead of all providers.
:param tenant_id: workspace id
:param provider_name: provider name
:return: ProviderConfiguration or None if provider not found
"""
# Get provider entity from factory
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entity = None
for entity in model_provider_factory.get_providers():
if entity.provider == provider_name:
provider_entity = entity
break
if not provider_entity:
return None
# Handle include, exclude filtering
if is_filtered(
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
return None
# Query only records for this specific provider
provider_records = self._get_provider_records(tenant_id, provider_name)
provider_model_records = self._get_provider_model_records(tenant_id, provider_name)
provider_model_credentials = self._get_provider_model_credentials_single(tenant_id, provider_name)
provider_model_settings = self._get_provider_model_settings_single(tenant_id, provider_name)
provider_load_balancing_configs = self._get_provider_load_balancing_configs_single(tenant_id, provider_name)
preferred_provider_type_record = self._get_preferred_model_provider(tenant_id, provider_name)
provider_id_entity = ModelProviderID(provider_name)
if provider_id_entity.is_langgenius():
alt_provider_name = provider_id_entity.provider_name
if alt_provider_name != provider_name:
provider_model_records.extend(self._get_provider_model_records(tenant_id, alt_provider_name))
provider_model_credentials.extend(
self._get_provider_model_credentials_single(tenant_id, alt_provider_name)
)
if provider_model_settings is not None:
provider_model_settings.extend(
self._get_provider_model_settings_single(tenant_id, alt_provider_name) or []
)
if provider_load_balancing_configs is not None:
provider_load_balancing_configs.extend(
self._get_provider_load_balancing_configs_single(tenant_id, alt_provider_name) or []
)
# Convert to custom configuration
custom_configuration = self._to_custom_configuration(
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
)
# Convert to system configuration
system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records)
# Get preferred provider type
if preferred_provider_type_record:
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:
preferred_provider_type = ProviderType.CUSTOM
elif system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
else:
preferred_provider_type = ProviderType.CUSTOM
using_provider_type = preferred_provider_type
has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
if preferred_provider_type == ProviderType.SYSTEM:
if not system_configuration.enabled or not has_valid_quota:
using_provider_type = ProviderType.CUSTOM
else:
if not custom_configuration.provider and not custom_configuration.models:
if system_configuration.enabled and has_valid_quota:
using_provider_type = ProviderType.SYSTEM
# Convert to model settings
model_settings = self._to_model_settings(
provider_entity=provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=provider_load_balancing_configs,
)
provider_configuration = ProviderConfiguration(
tenant_id=tenant_id,
provider=provider_entity,
preferred_provider_type=preferred_provider_type,
using_provider_type=using_provider_type,
system_configuration=system_configuration,
custom_configuration=custom_configuration,
model_settings=model_settings,
)
provider_configuration.bind_model_runtime(self._model_runtime)
return provider_configuration
@staticmethod
def _get_provider_records(tenant_id: str, provider_name: str) -> list[Provider]:
"""Get provider records for a specific provider name."""
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name.in_(provider_names),
Provider.is_valid == True,
)
return list(session.scalars(stmt))
@staticmethod
def _get_provider_model_records(tenant_id: str, provider_name: str) -> list[ProviderModel]:
"""Get provider model records for a specific provider name."""
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModel).where(
ProviderModel.tenant_id == tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.is_valid == True,
)
return list(session.scalars(stmt))
@staticmethod
def _get_provider_model_credentials_single(tenant_id: str, provider_name: str) -> list[ProviderModelCredential]:
"""Get provider model credentials for a specific provider name."""
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name.in_(provider_names),
)
return list(session.scalars(stmt))
@staticmethod
def _get_provider_model_settings_single(tenant_id: str, provider_name: str) -> list[ProviderModelSetting] | None:
"""Get provider model settings for a specific provider name."""
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
)
results = list(session.scalars(stmt))
return results or None
@staticmethod
def _get_provider_load_balancing_configs_single(
tenant_id: str, provider_name: str
) -> list[LoadBalancingModelConfig] | None:
"""Get provider load balancing configs for a specific provider name."""
cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
cache_result = redis_client.get(cache_key)
if cache_result is None:
model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
else:
cache_result = cache_result.decode("utf-8")
model_load_balancing_enabled = cache_result == "True"
if not model_load_balancing_enabled:
return None
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
)
return list(session.scalars(stmt))
@staticmethod
def _get_preferred_model_provider(tenant_id: str, provider_name: str) -> TenantPreferredModelProvider | None:
"""Get preferred provider type for a specific provider."""
provider_names = ProviderManager._get_provider_names(provider_name)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)
# Return first match since provider_names may have multiple forms
results = list(session.scalars(stmt))
return results[0] if results else None
def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle:
"""
Get provider model bundle.
@ -310,10 +500,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
provider_configurations = self.get_configurations(tenant_id)
# get provider instance
provider_configuration = provider_configurations.get(provider)
provider_configuration = self.get_single_provider_configuration(tenant_id, provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")

View File

@ -338,7 +338,7 @@ def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, e
def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture):
manager = _build_provider_manager(mocker)
with patch.object(manager, "get_configurations", return_value={}):
with patch.object(manager, "get_single_provider_configuration", return_value={}):
with pytest.raises(ValueError, match="Provider openai does not exist."):
manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
@ -452,7 +452,7 @@ def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker:
expected_bundle = Mock()
with (
patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}),
patch.object(manager, "get_single_provider_configuration", return_value=provider_configuration),
patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle,
):
result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)