diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index b290ae456e..6697b4b101 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -302,6 +302,196 @@ class ProviderManager: # Return the encapsulated object return provider_configurations + def get_single_provider_configuration(self, tenant_id: str, provider_name: str) -> ProviderConfiguration | None: + """ + Get single provider configuration efficiently. + Only queries data for the specified provider instead of all providers. + + :param tenant_id: workspace id + :param provider_name: provider name + :return: ProviderConfiguration or None if provider not found + """ + # Get provider entity from factory + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) + provider_entity = None + for entity in model_provider_factory.get_providers(): + if entity.provider == provider_name: + provider_entity = entity + break + + if not provider_entity: + return None + + # Handle include, exclude filtering + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + return None + + # Query only records for this specific provider + provider_records = self._get_provider_records(tenant_id, provider_name) + provider_model_records = self._get_provider_model_records(tenant_id, provider_name) + provider_model_credentials = self._get_provider_model_credentials_single(tenant_id, provider_name) + provider_model_settings = self._get_provider_model_settings_single(tenant_id, provider_name) + provider_load_balancing_configs = self._get_provider_load_balancing_configs_single(tenant_id, provider_name) + preferred_provider_type_record = self._get_preferred_model_provider(tenant_id, provider_name) + + provider_id_entity = ModelProviderID(provider_name) + if provider_id_entity.is_langgenius(): + alt_provider_name = provider_id_entity.provider_name + if alt_provider_name != provider_name: + provider_model_records.extend(self._get_provider_model_records(tenant_id, alt_provider_name)) + provider_model_credentials.extend( + self._get_provider_model_credentials_single(tenant_id, alt_provider_name) + ) + if provider_model_settings is not None: + provider_model_settings.extend( + self._get_provider_model_settings_single(tenant_id, alt_provider_name) or [] + ) + if provider_load_balancing_configs is not None: + provider_load_balancing_configs.extend( + self._get_provider_load_balancing_configs_single(tenant_id, alt_provider_name) or [] + ) + + # Convert to custom configuration + custom_configuration = self._to_custom_configuration( + tenant_id, provider_entity, provider_records, provider_model_records, provider_model_credentials + ) + + # Convert to system configuration + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) + + # Get preferred provider type + if preferred_provider_type_record: + preferred_provider_type = preferred_provider_type_record.preferred_provider_type + elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM + elif custom_configuration.provider or custom_configuration.models: + preferred_provider_type = ProviderType.CUSTOM + elif system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM + else: + preferred_provider_type = ProviderType.CUSTOM + + using_provider_type = preferred_provider_type + has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) + + if preferred_provider_type == ProviderType.SYSTEM: + if not system_configuration.enabled or not has_valid_quota: + using_provider_type = ProviderType.CUSTOM + else: + if not custom_configuration.provider and not custom_configuration.models: + if system_configuration.enabled and has_valid_quota: + using_provider_type = ProviderType.SYSTEM + + # Convert to model settings + model_settings = self._to_model_settings( + provider_entity=provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=provider_load_balancing_configs, + ) + + provider_configuration = ProviderConfiguration( + tenant_id=tenant_id, + provider=provider_entity, + preferred_provider_type=preferred_provider_type, + using_provider_type=using_provider_type, + system_configuration=system_configuration, + custom_configuration=custom_configuration, + model_settings=model_settings, + ) + provider_configuration.bind_model_runtime(self._model_runtime) + + return provider_configuration + + @staticmethod + def _get_provider_records(tenant_id: str, provider_name: str) -> list[Provider]: + """Get provider records for a specific provider name.""" + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name.in_(provider_names), + Provider.is_valid == True, + ) + return list(session.scalars(stmt)) + + @staticmethod + def _get_provider_model_records(tenant_id: str, provider_name: str) -> list[ProviderModel]: + """Get provider model records for a specific provider name.""" + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name.in_(provider_names), + ProviderModel.is_valid == True, + ) + return list(session.scalars(stmt)) + + @staticmethod + def _get_provider_model_credentials_single(tenant_id: str, provider_name: str) -> list[ProviderModelCredential]: + """Get provider model credentials for a specific provider name.""" + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelCredential).where( + ProviderModelCredential.tenant_id == tenant_id, + ProviderModelCredential.provider_name.in_(provider_names), + ) + return list(session.scalars(stmt)) + + @staticmethod + def _get_provider_model_settings_single(tenant_id: str, provider_name: str) -> list[ProviderModelSetting] | None: + """Get provider model settings for a specific provider name.""" + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelSetting).where( + ProviderModelSetting.tenant_id == tenant_id, + ProviderModelSetting.provider_name.in_(provider_names), + ) + results = list(session.scalars(stmt)) + return results or None + + @staticmethod + def _get_provider_load_balancing_configs_single( + tenant_id: str, provider_name: str + ) -> list[LoadBalancingModelConfig] | None: + """Get provider load balancing configs for a specific provider name.""" + cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled" + cache_result = redis_client.get(cache_key) + if cache_result is None: + model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled + redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) + else: + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" + + if not model_load_balancing_enabled: + return None + + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where( + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name.in_(provider_names), + ) + return list(session.scalars(stmt)) + + @staticmethod + def _get_preferred_model_provider(tenant_id: str, provider_name: str) -> TenantPreferredModelProvider | None: + """Get preferred provider type for a specific provider.""" + provider_names = ProviderManager._get_provider_names(provider_name) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name.in_(provider_names), + ) + # Return first match since provider_names may have multiple forms + results = list(session.scalars(stmt)) + return results[0] if results else None + def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle: """ Get provider model bundle. @@ -310,10 +500,7 @@ class ProviderManager: :param model_type: model type :return: """ - provider_configurations = self.get_configurations(tenant_id) - - # get provider instance - provider_configuration = provider_configurations.get(provider) + provider_configuration = self.get_single_provider_configuration(tenant_id, provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 02f12fb3b4..f30a6a2a6c 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -338,7 +338,7 @@ def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, e def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): manager = _build_provider_manager(mocker) - with patch.object(manager, "get_configurations", return_value={}): + with patch.object(manager, "get_single_provider_configuration", return_value={}): with pytest.raises(ValueError, match="Provider openai does not exist."): manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) @@ -452,7 +452,7 @@ def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: expected_bundle = Mock() with ( - patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch.object(manager, "get_single_provider_configuration", return_value=provider_configuration), patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, ): result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM)