mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
Merge 38bedb8616 into 19bf36a716
This commit is contained in:
commit
35ca3190ed
@ -302,6 +302,196 @@ class ProviderManager:
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
def get_single_provider_configuration(self, tenant_id: str, provider_name: str) -> ProviderConfiguration | None:
|
||||
"""
|
||||
Get single provider configuration efficiently.
|
||||
Only queries data for the specified provider instead of all providers.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:return: ProviderConfiguration or None if provider not found
|
||||
"""
|
||||
# Get provider entity from factory
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
provider_entity = None
|
||||
for entity in model_provider_factory.get_providers():
|
||||
if entity.provider == provider_name:
|
||||
provider_entity = entity
|
||||
break
|
||||
|
||||
if not provider_entity:
|
||||
return None
|
||||
|
||||
# Handle include, exclude filtering
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
return None
|
||||
|
||||
# Query only records for this specific provider
|
||||
provider_records = self._get_provider_records(tenant_id, provider_name)
|
||||
provider_model_records = self._get_provider_model_records(tenant_id, provider_name)
|
||||
provider_model_credentials = self._get_provider_model_credentials_single(tenant_id, provider_name)
|
||||
provider_model_settings = self._get_provider_model_settings_single(tenant_id, provider_name)
|
||||
provider_load_balancing_configs = self._get_provider_load_balancing_configs_single(tenant_id, provider_name)
|
||||
preferred_provider_type_record = self._get_preferred_model_provider(tenant_id, provider_name)
|
||||
|
||||
provider_id_entity = ModelProviderID(provider_name)
|
||||
if provider_id_entity.is_langgenius():
|
||||
alt_provider_name = provider_id_entity.provider_name
|
||||
if alt_provider_name != provider_name:
|
||||
provider_model_records.extend(self._get_provider_model_records(tenant_id, alt_provider_name))
|
||||
provider_model_credentials.extend(
|
||||
self._get_provider_model_credentials_single(tenant_id, alt_provider_name)
|
||||
)
|
||||
if provider_model_settings is not None:
|
||||
provider_model_settings.extend(
|
||||
self._get_provider_model_settings_single(tenant_id, alt_provider_name) or []
|
||||
)
|
||||
if provider_load_balancing_configs is not None:
|
||||
provider_load_balancing_configs.extend(
|
||||
self._get_provider_load_balancing_configs_single(tenant_id, alt_provider_name) or []
|
||||
)
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records)
|
||||
|
||||
# Get preferred provider type
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = preferred_provider_type_record.preferred_provider_type
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
else:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
|
||||
using_provider_type = preferred_provider_type
|
||||
has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
|
||||
|
||||
if preferred_provider_type == ProviderType.SYSTEM:
|
||||
if not system_configuration.enabled or not has_valid_quota:
|
||||
using_provider_type = ProviderType.CUSTOM
|
||||
else:
|
||||
if not custom_configuration.provider and not custom_configuration.models:
|
||||
if system_configuration.enabled and has_valid_quota:
|
||||
using_provider_type = ProviderType.SYSTEM
|
||||
|
||||
# Convert to model settings
|
||||
model_settings = self._to_model_settings(
|
||||
provider_entity=provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=provider_load_balancing_configs,
|
||||
)
|
||||
|
||||
provider_configuration = ProviderConfiguration(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_entity,
|
||||
preferred_provider_type=preferred_provider_type,
|
||||
using_provider_type=using_provider_type,
|
||||
system_configuration=system_configuration,
|
||||
custom_configuration=custom_configuration,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
provider_configuration.bind_model_runtime(self._model_runtime)
|
||||
|
||||
return provider_configuration
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_records(tenant_id: str, provider_name: str) -> list[Provider]:
|
||||
"""Get provider records for a specific provider name."""
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name.in_(provider_names),
|
||||
Provider.is_valid == True,
|
||||
)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_model_records(tenant_id: str, provider_name: str) -> list[ProviderModel]:
|
||||
"""Get provider model records for a specific provider name."""
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModel).where(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.is_valid == True,
|
||||
)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_model_credentials_single(tenant_id: str, provider_name: str) -> list[ProviderModelCredential]:
|
||||
"""Get provider model credentials for a specific provider name."""
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(provider_names),
|
||||
)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_model_settings_single(tenant_id: str, provider_name: str) -> list[ProviderModelSetting] | None:
|
||||
"""Get provider model settings for a specific provider name."""
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(provider_names),
|
||||
)
|
||||
results = list(session.scalars(stmt))
|
||||
return results or None
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_load_balancing_configs_single(
|
||||
tenant_id: str, provider_name: str
|
||||
) -> list[LoadBalancingModelConfig] | None:
|
||||
"""Get provider load balancing configs for a specific provider name."""
|
||||
cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
|
||||
cache_result = redis_client.get(cache_key)
|
||||
if cache_result is None:
|
||||
model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
|
||||
redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
|
||||
else:
|
||||
cache_result = cache_result.decode("utf-8")
|
||||
model_load_balancing_enabled = cache_result == "True"
|
||||
|
||||
if not model_load_balancing_enabled:
|
||||
return None
|
||||
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
@staticmethod
|
||||
def _get_preferred_model_provider(tenant_id: str, provider_name: str) -> TenantPreferredModelProvider | None:
|
||||
"""Get preferred provider type for a specific provider."""
|
||||
provider_names = ProviderManager._get_provider_names(provider_name)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
||||
)
|
||||
# Return first match since provider_names may have multiple forms
|
||||
results = list(session.scalars(stmt))
|
||||
return results[0] if results else None
|
||||
|
||||
def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle:
|
||||
"""
|
||||
Get provider model bundle.
|
||||
@ -310,10 +500,7 @@ class ProviderManager:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
provider_configurations = self.get_configurations(tenant_id)
|
||||
|
||||
# get provider instance
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
provider_configuration = self.get_single_provider_configuration(tenant_id, provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
|
||||
@ -338,7 +338,7 @@ def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, e
|
||||
def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
|
||||
with patch.object(manager, "get_configurations", return_value={}):
|
||||
with patch.object(manager, "get_single_provider_configuration", return_value={}):
|
||||
with pytest.raises(ValueError, match="Provider openai does not exist."):
|
||||
manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
|
||||
|
||||
@ -452,7 +452,7 @@ def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker:
|
||||
expected_bundle = Mock()
|
||||
|
||||
with (
|
||||
patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}),
|
||||
patch.object(manager, "get_single_provider_configuration", return_value=provider_configuration),
|
||||
patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle,
|
||||
):
|
||||
result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user