diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 5631caa1a59..765268f7a0e 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -4,6 +4,7 @@ from copy import deepcopy from typing import Any from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelWithProviderEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager @@ -19,6 +20,10 @@ class DifyCredentialsProvider: Fetched credentials are stored in :attr:`credentials_cache` and reused for subsequent ``fetch`` calls for the same ``(provider_name, model_name)``. + The matching validated provider model is cached alongside the credentials so + follow-up workflow startup checks can reuse the earlier provider lookup + instead of resolving the same model metadata a second time. + Because of that cache, a single instance can return stale credentials after the tenant or provider configuration changes (e.g. API key rotation). @@ -30,6 +35,7 @@ class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager credentials_cache: dict[tuple[str, str], dict[str, Any]] + provider_model_cache: dict[tuple[str, str], ModelWithProviderEntity] def __init__( self, @@ -45,12 +51,21 @@ class DifyCredentialsProvider: ) self.provider_manager = provider_manager self.credentials_cache = {} + self.provider_model_cache = {} + + def get_cached_provider_model(self, provider_name: str, model_name: str) -> ModelWithProviderEntity | None: + provider_model = self.provider_model_cache.get((provider_name, model_name)) + if provider_model is None: + return None + + return provider_model.model_copy(deep=True) def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: if (provider_name, model_name) in self.credentials_cache: return deepcopy(self.credentials_cache[(provider_name, model_name)]) provider_configurations = self.provider_manager.get_configurations(self.tenant_id) + provider_configuration = provider_configurations.get(provider_name) if not provider_configuration: raise ValueError(f"Provider {provider_name} does not exist.") @@ -65,6 +80,7 @@ class DifyCredentialsProvider: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) + self.provider_model_cache[(provider_name, model_name)] = provider_model.model_copy(deep=True) return credentials @@ -142,13 +158,21 @@ def fetch_model_config( model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name) provider_model_bundle = model_instance.provider_model_bundle - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, - ) + provider_model = None + if isinstance(credentials_provider, DifyCredentialsProvider): + provider_model = credentials_provider.get_cached_provider_model( + provider_name=node_data_model.provider, + model_name=node_data_model.name, + ) + if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") - provider_model.raise_for_status() + provider_model = provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, + model_type=ModelType.LLM, + ) + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") + provider_model.raise_for_status() model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if model_schema is None: diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 91c46d07a80..95375a8cbb2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -69,6 +69,11 @@ class ProviderConfiguration(BaseModel): nested schema and model lookups reuse the caller scope that was already resolved by the composition layer. + The ``provider`` field already contains the resolved provider schema that + was used to build this configuration. Reuse that schema for nested model + lookups instead of refetching the full provider catalog from the runtime on + every request-scoped lookup. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -83,15 +88,19 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) + _cached_provider_schema: ProviderEntity | None = PrivateAttr(default=None) + _original_provider_configurate_methods: tuple[ConfigurateMethod, ...] = PrivateAttr(default_factory=tuple) @model_validator(mode="after") def _(self): + self._original_provider_configurate_methods = tuple(self.provider.configurate_methods) + if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in self.provider.configurate_methods: original_provider_configurate_methods[self.provider.provider].append(configurate_method) - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if list(self._original_provider_configurate_methods) == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: if ( any( len(quota_configuration.restrict_models) > 0 @@ -105,6 +114,29 @@ class ProviderConfiguration(BaseModel): def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: """Attach the already-composed runtime for request-bound call chains.""" self._bound_model_runtime = model_runtime + self._cached_provider_schema = self.provider + + def _get_original_provider_configurate_methods(self) -> list[ConfigurateMethod]: + return list(self._original_provider_configurate_methods) + + def _get_provider_schema(self, *, model_provider_factory: ModelProviderFactory | None = None) -> ProviderEntity: + """Resolve the provider schema lazily while preserving bound-runtime reuse.""" + if self._cached_provider_schema is None: + if self.provider.models: + self._cached_provider_schema = self.provider + else: + provider_factory = model_provider_factory or self.get_model_provider_factory() + self._cached_provider_schema = provider_factory.get_provider_schema(provider=self.provider.provider) + + return self._cached_provider_schema + + def _get_model_runtime(self) -> ModelRuntime: + """Return the runtime aligned with this request-scoped configuration.""" + if self._bound_model_runtime is not None: + return self._bound_model_runtime + + model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id) + return model_assembly.model_runtime def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]: """Resolve a provider factory that stays aligned with the runtime used by the caller.""" @@ -153,7 +185,6 @@ class ProviderConfiguration(BaseModel): and restrict_model.base_model_name ): copy_credentials["base_model_name"] = restrict_model.base_model_name - return copy_credentials else: credentials = None @@ -189,7 +220,6 @@ class ProviderConfiguration(BaseModel): provider=self.provider.provider, credential_type=PluginCredentialType.MODEL, ) - return credentials def get_system_configuration_status(self) -> SystemConfigurationStatus | None: @@ -1399,8 +1429,13 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_runtime, model_provider_factory = self._get_runtime_and_provider_factory() - provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider) + if self._bound_model_runtime is not None: + model_runtime = self._bound_model_runtime + else: + model_runtime, _ = self._get_runtime_and_provider_factory() + + provider_schema = self._cached_provider_schema or self.provider + return create_model_type_instance( runtime=model_runtime, provider_schema=provider_schema, @@ -1410,12 +1445,13 @@ class ProviderConfiguration(BaseModel): def get_model_schema( self, model_type: ModelType, model: str, credentials: dict[str, Any] | None ) -> AIModelEntity | None: - """ - Get model schema - """ - model_provider_factory = self.get_model_provider_factory() - return model_provider_factory.get_model_schema( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + """Get model schema with the request-bound runtime and canonical provider id.""" + model_runtime = self._get_model_runtime() + return model_runtime.get_model_schema( + provider=self.provider.provider, + model_type=model_type, + model=model, + credentials=credentials or {}, ) def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None): @@ -1515,8 +1551,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = self.get_model_provider_factory() - provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) + provider_schema = self._get_provider_schema() model_types: list[ModelType] = [] if model_type: @@ -1531,7 +1566,10 @@ class ProviderConfiguration(BaseModel): if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map + model_types=model_types, + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model=model, ) else: provider_models = self._get_custom_provider_models( @@ -1573,6 +1611,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], + model: str | None = None, ) -> list[ModelWithProviderEntity]: """ Get system provider models. @@ -1587,6 +1626,8 @@ class ProviderConfiguration(BaseModel): for m in provider_schema.models: if m.model_type != model_type: continue + if model and m.model != model: + continue status = ModelStatus.ACTIVE if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: @@ -1608,13 +1649,9 @@ class ProviderConfiguration(BaseModel): ) ) - if self.provider.provider not in original_provider_configurate_methods: - original_provider_configurate_methods[self.provider.provider] = [] - for configurate_method in provider_schema.configurate_methods: - original_provider_configurate_methods[self.provider.provider].append(configurate_method) - + original_configurate_methods = self._get_original_provider_configurate_methods() should_use_custom_model = False - if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: + if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: should_use_custom_model = True for quota_configuration in self.system_configuration.quota_configurations: @@ -1626,11 +1663,12 @@ class ProviderConfiguration(BaseModel): break if should_use_custom_model: - if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL - ]: + if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: # only customizable model for restrict_model in restrict_models: + if model and restrict_model.model != model: + continue + copy_credentials = ( self.system_configuration.credentials.copy() if self.system_configuration.credentials @@ -1680,11 +1718,11 @@ class ProviderConfiguration(BaseModel): # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] - for model in provider_models: - if model.model_type == ModelType.LLM and model.model not in restrict_model_names: - model.status = ModelStatus.NO_PERMISSION + for provider_model in provider_models: + if provider_model.model_type == ModelType.LLM and provider_model.model not in restrict_model_names: + provider_model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: - model.status = ModelStatus.QUOTA_EXCEEDED + provider_model.status = ModelStatus.QUOTA_EXCEEDED return provider_models @@ -1709,6 +1747,13 @@ class ProviderConfiguration(BaseModel): if self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials + requested_predefined_model = False + if model: + requested_predefined_model = any( + predefined_model.model_type in model_types and predefined_model.model == model + for predefined_model in provider_schema.models + ) + for model_type in model_types: if model_type not in self.provider.supported_model_types: continue @@ -1716,6 +1761,8 @@ class ProviderConfiguration(BaseModel): for m in provider_schema.models: if m.model_type != model_type: continue + if requested_predefined_model and model and m.model != model: + continue status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE load_balancing_enabled = False diff --git a/api/core/plugin/plugin_service.py b/api/core/plugin/plugin_service.py index 50b35afbcd0..79c372690e6 100644 --- a/api/core/plugin/plugin_service.py +++ b/api/core/plugin/plugin_service.py @@ -13,8 +13,10 @@ metadata. """ import logging +import time from collections.abc import Mapping, Sequence from mimetypes import guess_type +from typing import ClassVar from pydantic import BaseModel, TypeAdapter, ValidationError from redis import RedisError @@ -64,6 +66,8 @@ _provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list class PluginService: + _plugin_model_providers_memory_cache: ClassVar[dict[str, tuple[int, float, tuple[ProviderEntity, ...]]]] = {} + class LatestPluginCache(BaseModel): plugin_id: str version: str @@ -75,14 +79,22 @@ class PluginService: REDIS_KEY_PREFIX = "plugin_service:latest_plugin:" REDIS_TTL = 60 * 5 # 5 minutes PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:" + PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX = "plugin_model_providers_generation:tenant_id:" PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed) # Mirror the detail-panel endpoint query size so list reconciliation and # the visible endpoint drawer exercise the same daemon pagination path. ENDPOINT_RECONCILIATION_PAGE_SIZE = 100 @classmethod - def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str: - return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}" + def _get_plugin_model_providers_cache_key(cls, tenant_id: str, generation: int | None = None) -> str: + if generation is None: + return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}" + + return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}:generation:{generation}" + + @classmethod + def _get_plugin_model_providers_generation_cache_key(cls, tenant_id: str) -> str: + return f"{cls.PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX}{tenant_id}" @staticmethod def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str: @@ -115,29 +127,133 @@ class PluginService: return declaration @classmethod - def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None: - cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + def _copy_provider_entities(cls, providers: Sequence[ProviderEntity]) -> tuple[ProviderEntity, ...]: + return tuple(provider.model_copy(deep=True) for provider in providers) + + @classmethod + def _load_plugin_model_providers_generation(cls, tenant_id: str) -> int | None: + cache_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id) try: - cached_providers = redis_client.get(cache_key) + cached_generation = redis_client.get(cache_key) + except (RedisError, RuntimeError): + logger.warning("Failed to read plugin model provider generation for tenant %s.", tenant_id, exc_info=True) + return None + + if cached_generation is None: + return 0 + + try: + return int(cached_generation) + except (TypeError, ValueError): + logger.warning( + "Invalid plugin model provider generation for tenant %s; deleting cache marker.", + tenant_id, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError): + logger.warning( + "Failed to delete invalid plugin model provider generation for tenant %s.", + tenant_id, + exc_info=True, + ) + return None + + @classmethod + def _load_in_memory_plugin_model_providers( + cls, memory_cache_key: str, generation: int + ) -> tuple[ProviderEntity, ...] | None: + cached_entry = cls._plugin_model_providers_memory_cache.get(memory_cache_key) + if cached_entry is None: + return None + + cached_generation, expires_at, providers = cached_entry + if cached_generation != generation or time.monotonic() >= expires_at: + cls._plugin_model_providers_memory_cache.pop(memory_cache_key, None) + return None + + return cls._copy_provider_entities(providers) + + @classmethod + def _store_in_memory_plugin_model_providers( + cls, memory_cache_key: str, generation: int, providers: Sequence[ProviderEntity] + ) -> None: + ttl = dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL + if ttl <= 0: + cls._plugin_model_providers_memory_cache.pop(memory_cache_key, None) + return + + cls._plugin_model_providers_memory_cache[memory_cache_key] = ( + generation, + time.monotonic() + ttl, + cls._copy_provider_entities(providers), + ) + + @classmethod + def _load_cached_plugin_model_providers( + cls, tenant_id: str, *, client: PluginModelClient | None = None + ) -> tuple[ProviderEntity, ...] | None: + generation = cls._load_plugin_model_providers_generation(tenant_id) + if generation is not None: + in_memory_cached_providers = cls._load_in_memory_plugin_model_providers(tenant_id, generation) + if in_memory_cached_providers is not None: + return in_memory_cached_providers + + cache_keys = [] + if generation is not None: + cache_keys.append(cls._get_plugin_model_providers_cache_key(tenant_id, generation)) + if generation == 0: + cache_keys.append(cls._get_plugin_model_providers_cache_key(tenant_id)) + + if not cache_keys: + return None + + try: + cached_provider_entries = redis_client.mget(cache_keys) except (RedisError, RuntimeError): logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True) return None - if not cached_providers: + if len(cached_provider_entries) != len(cache_keys): + logger.warning( + "Unexpected cached plugin model providers response size for tenant %s.", + tenant_id, + ) return None - try: - return tuple(_provider_entities_adapter.validate_json(cached_providers)) - except (TypeError, ValueError, ValidationError): - logger.warning( - "Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True - ) - cls.invalidate_plugin_model_providers_cache(tenant_id) - return None + for cache_key, cached_providers in zip(cache_keys, cached_provider_entries): + if not cached_providers: + continue + + try: + providers = tuple(_provider_entities_adapter.validate_json(cached_providers)) + if generation is not None: + cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers) + return providers + except (TypeError, ValueError, ValidationError): + logger.warning( + "Invalid cached plugin model providers for tenant %s; deleting cache key %s.", + tenant_id, + cache_key, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError): + logger.warning( + "Failed to delete invalid cached plugin model providers for tenant %s.", + tenant_id, + exc_info=True, + ) + + return None @classmethod - def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None: - cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + def _store_cached_plugin_model_providers( + cls, tenant_id: str, generation: int, providers: Sequence[ProviderEntity] + ) -> None: + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id, generation) try: payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8") redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload) @@ -146,9 +262,15 @@ class PluginService: @classmethod def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None: - """Delete the tenant-scoped plugin model provider list cache.""" + """Invalidate tenant-scoped provider metadata across Redis and worker-local mirrors.""" + cls._plugin_model_providers_memory_cache.pop(tenant_id, None) + cache_key = cls._get_plugin_model_providers_cache_key(tenant_id) + generation_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id) try: - redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id)) + pipe = redis_client.pipeline(transaction=False) + pipe.delete(cache_key) + pipe.incr(generation_key) + pipe.execute() except (RedisError, RuntimeError): logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True) @@ -163,7 +285,7 @@ class PluginService: are intentionally owned by this service so tenant isolation and cache expiry are handled in one place. """ - cached_providers = cls._load_cached_plugin_model_providers(tenant_id) + cached_providers = cls._load_cached_plugin_model_providers(tenant_id, client=client) if cached_providers is not None: return cached_providers @@ -171,7 +293,12 @@ class PluginService: providers = tuple( cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id) ) - cls._store_cached_plugin_model_providers(tenant_id, providers) + if not providers: + return providers + generation = cls._load_plugin_model_providers_generation(tenant_id) + if generation is not None: + cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers) + cls._store_cached_plugin_model_providers(tenant_id, generation, providers) return providers @staticmethod diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 1b714d68307..07ba9314977 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -433,10 +433,9 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") mock_factory = Mock() - mock_factory.get_provider_schema.return_value = configuration.provider - mock_factory.get_model_schema.return_value = mock_schema mock_assembly = Mock() mock_assembly.model_runtime = Mock() + mock_assembly.model_runtime.get_model_schema.return_value = mock_schema mock_assembly.model_provider_factory = mock_factory with ( @@ -455,13 +454,12 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema assert mock_assembly_builder.call_count == 2 - mock_factory.get_provider_schema.assert_called_once_with(provider="openai") mock_model_builder.assert_called_once_with( runtime=mock_assembly.model_runtime, provider_schema=configuration.provider, model_type=ModelType.LLM, ) - mock_factory.get_model_schema.assert_called_once_with( + mock_assembly.model_runtime.get_model_schema.assert_called_once_with( provider="openai", model_type=ModelType.LLM, model="gpt-4o", @@ -472,18 +470,13 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: configuration = _build_provider_configuration() bound_runtime = Mock() + bound_runtime.get_model_schema.return_value = _build_ai_model("gpt-4o") configuration.bind_model_runtime(bound_runtime) mock_model_type_instance = Mock() - mock_schema = _build_ai_model("gpt-4o") - mock_factory = Mock() - mock_factory.get_provider_schema.return_value = configuration.provider - mock_factory.get_model_schema.return_value = mock_schema with ( - patch( - "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory - ) as mock_factory_cls, + patch("core.entities.provider_configuration.ModelProviderFactory") as mock_factory_cls, patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder, patch( "core.entities.provider_configuration.create_model_type_instance", @@ -494,16 +487,20 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance - assert model_schema is mock_schema - assert mock_factory_cls.call_count == 2 - mock_factory_cls.assert_called_with(runtime=bound_runtime) + assert model_schema == bound_runtime.get_model_schema.return_value + mock_factory_cls.assert_not_called() mock_assembly_builder.assert_not_called() - mock_factory.get_provider_schema.assert_called_once_with(provider="openai") mock_model_builder.assert_called_once_with( runtime=bound_runtime, provider_schema=configuration.provider, model_type=ModelType.LLM, ) + bound_runtime.get_model_schema.assert_called_once_with( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"api_key": "x"}, + ) def test_get_provider_model_returns_none_when_model_not_found() -> None: @@ -544,6 +541,99 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N assert [model.model for model in active_models] == ["b-model"] +def test_get_provider_models_system_filters_requested_model() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("a-model"), _build_ai_model("target-model"), _build_ai_model("b-model")], + ) + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = provider_schema + + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ): + models = configuration.get_provider_models( + model_type=ModelType.LLM, + only_active=False, + model="target-model", + ) + + assert [model.model for model in models] == ["target-model"] + + +def test_get_provider_models_system_customizable_filters_requested_restricted_model() -> None: + provider = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1_000, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target-model", base_model_name="base-model", model_type=ModelType.LLM), + RestrictModel(model="other-model", base_model_name="base-model", model_type=ModelType.LLM), + ], + ) + ], + ) + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[], + ) + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = provider_schema + + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + configuration = ProviderConfiguration( + tenant_id="tenant-1", + provider=provider, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + with ( + patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ), + patch.object( + ProviderConfiguration, + "get_model_schema", + side_effect=lambda *args, **kwargs: _build_ai_model(kwargs["model"]), + ) as mock_get_model_schema, + ): + models = configuration.get_provider_models( + model_type=ModelType.LLM, + only_active=False, + model="target-model", + ) + + assert [model.model for model in models] == ["target-model"] + mock_get_model_schema.assert_called_once() + assert mock_get_model_schema.call_args.kwargs["model"] == "target-model" + + def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None: configuration = _build_provider_configuration() configuration.using_provider_type = ProviderType.CUSTOM @@ -611,6 +701,48 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva assert invalid_lb_map["custom-model"] is True +def test_get_custom_provider_models_filters_requested_base_model() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-model"), _build_ai_model("target-model")], + ) + + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + model="target-model", + ) + + assert [model.model for model in models] == ["target-model"] + + +def test_get_provider_models_reuses_cached_provider_schema() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("a-model"), _build_ai_model("b-model")], + ) + configuration.provider = provider_schema + + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + ) as mock_assembly_builder: + configuration.get_provider_models(model_type=ModelType.LLM, model="a-model") + configuration.get_provider_models(model_type=ModelType.LLM, model="b-model") + + mock_assembly_builder.assert_not_called() + + def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None: provider = ProviderEntity( provider="openai", @@ -1402,25 +1534,22 @@ def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() - return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING) return _build_ai_model("target") - with patch( - "core.entities.provider_configuration.original_provider_configurate_methods", - {"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]}, - ): - with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema): - system_models = configuration._get_system_provider_models( - model_types=[ModelType.LLM], - provider_schema=provider_schema, - model_setting_map={ - ModelType.LLM: { - "target": ModelSettings( - model="target", - model_type=ModelType.LLM, - enabled=False, - load_balancing_configs=[], - ) - } - }, - ) + configuration._original_provider_configurate_methods = (ConfigurateMethod.CUSTOMIZABLE_MODEL,) + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema): + system_models = configuration._get_system_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={ + ModelType.LLM: { + "target": ModelSettings( + model="target", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + ) + } + }, + ) assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models) configuration.using_provider_type = ProviderType.CUSTOM diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index f9abc7d02a1..3fd885b28fb 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -28,6 +28,9 @@ class _FakeRedis: def get(self, key: str) -> str | None: return self._values.get(key) + def mget(self, keys: list[str]) -> list[str | None]: + return [self.get(key) for key in keys] + def setex(self, key: str, ttl: int, value: str) -> None: self._values[key] = value self.setex_calls.append((key, ttl, value)) @@ -36,6 +39,13 @@ class _FakeRedis: self._values.pop(key, None) +@pytest.fixture(autouse=True) +def clear_plugin_model_provider_memory_cache() -> None: + PluginService._plugin_model_providers_memory_cache.clear() + yield + PluginService._plugin_model_providers_memory_cache.clear() + + def _build_model_schema() -> AIModelEntity: return AIModelEntity( model="gpt-4o-mini", @@ -329,6 +339,7 @@ class TestPluginModelRuntime: "redis_client", SimpleNamespace( get=Mock(return_value=None), + mget=Mock(return_value=[None, None]), delete=Mock(), setex=Mock(), ), diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fb50723402d..d1fca6564b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -345,6 +345,62 @@ def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_confi provider_model.raise_for_status.assert_called_once() +def test_fetch_model_config_reuses_validated_provider_model_from_dify_credentials_provider( + model_config: ModelConfigWithCredentialsEntity, +): + mock_provider_manager = mock.MagicMock() + mock_configurations = mock.MagicMock() + mock_provider_configuration = mock.MagicMock() + mock_provider_model = mock.MagicMock() + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) + + mock_configurations.get.return_value = mock_provider_configuration + mock_provider_configuration.get_provider_model.return_value = mock_provider_model + mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} + mock_provider_manager.get_configurations.return_value = mock_configurations + + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + credentials_provider = DifyCredentialsProvider( + run_context=run_context, + provider_manager=mock_provider_manager, + ) + + model_instance = mock.MagicMock( + model_type_instance=model_config.provider_model_bundle.model_type_instance, + provider_model_bundle=model_config.provider_model_bundle, + ) + mock_model_factory.init_model_instance.return_value = model_instance + + with mock.patch.object( + model_instance.model_type_instance.__class__, + "get_model_schema", + return_value=model_config.model_schema, + autospec=True, + ): + fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params={}, + ), + credentials_provider=credentials_provider, + model_factory=mock_model_factory, + ) + + mock_provider_configuration.get_provider_model.assert_called_once_with( + model_type=ModelType.LLM, + model="gpt-3.5-turbo", + ) + mock_provider_model.raise_for_status.assert_called_once() + + def test_dify_model_access_adapters_call_managers(): mock_provider_manager = mock.MagicMock() mock_model_manager = mock.MagicMock() diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/unit_tests/services/plugin/test_plugin_service.py index f7b89401180..fca05f94fc7 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_service.py @@ -1,8 +1,9 @@ import datetime import uuid from types import SimpleNamespace -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, call, patch +import pytest from pydantic import TypeAdapter from redis import RedisError @@ -13,6 +14,15 @@ from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, MODULE = "core.plugin.plugin_service" +@pytest.fixture(autouse=True) +def clear_plugin_model_provider_memory_cache() -> None: + from core.plugin.plugin_service import PluginService + + PluginService._plugin_model_providers_memory_cache.clear() + yield + PluginService._plugin_model_providers_memory_cache.clear() + + class _FakeSession: def __init__(self) -> None: self.execute = Mock() @@ -68,6 +78,17 @@ def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskSta ) +def _provider_cache_key(tenant_id: str, generation: int | None = None) -> str: + if generation is None: + return f"plugin_model_providers:tenant_id:{tenant_id}" + + return f"plugin_model_providers:tenant_id:{tenant_id}:generation:{generation}" + + +def _provider_generation_key(tenant_id: str) -> str: + return f"plugin_model_providers_generation:tenant_id:{tenant_id}" + + class TestFetchLatestPluginVersion: def test_skips_marketplace_fetch_when_disabled(self) -> None: """Cache misses stay None; marketplace is never called when disabled.""" @@ -120,9 +141,13 @@ class TestPluginModelProviderCache: """A valid tenant cache entry is reused across runtime calls without plugin daemon access.""" cached_provider = _build_provider_entity() cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8") + generation_key = _provider_generation_key("tenant-1") + cache_key = _provider_cache_key("tenant-1", 0) + legacy_cache_key = _provider_cache_key("tenant-1") with patch(f"{MODULE}.redis_client") as redis_client: - redis_client.get.return_value = cached_payload + redis_client.get.return_value = None + redis_client.mget.return_value = [cached_payload, None] from core.plugin.plugin_service import PluginService @@ -132,14 +157,20 @@ class TestPluginModelProviderCache: assert [provider.provider for provider in result] == ["langgenius/openai/openai"] client.fetch_model_providers.assert_not_called() redis_client.setex.assert_not_called() + redis_client.get.assert_called_once_with(generation_key) + redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key]) def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None: - """Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon.""" + """Invalid generation-scoped cache payloads are removed before falling back to the daemon.""" + generation_key = _provider_generation_key("tenant-1") + cache_key = _provider_cache_key("tenant-1", 0) + legacy_cache_key = _provider_cache_key("tenant-1") with ( patch(f"{MODULE}.redis_client") as redis_client, patch(f"{MODULE}.dify_config") as mock_config, ): - redis_client.get.return_value = "not-json" + redis_client.get.side_effect = [None, None] + redis_client.mget.return_value = ["not-json", None] mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400 client = Mock() client.fetch_model_providers.return_value = [_build_plugin_model_provider()] @@ -148,12 +179,13 @@ class TestPluginModelProviderCache: result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) - cache_key = "plugin_model_providers:tenant_id:tenant-1" redis_client.delete.assert_called_once_with(cache_key) redis_client.setex.assert_called_once() assert redis_client.setex.call_args.args[0] == cache_key assert redis_client.setex.call_args.args[1] == 86400 assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + redis_client.get.assert_has_calls([call(generation_key), call(generation_key)]) + redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key]) def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None: """Redis read failures do not block provider discovery for the tenant.""" @@ -169,10 +201,29 @@ class TestPluginModelProviderCache: client.fetch_model_providers.assert_called_once_with("tenant-1") assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + def test_fetch_plugin_model_providers_refetches_when_cached_payload_batch_read_fails(self) -> None: + """Redis mget failures do not block provider discovery for the tenant.""" + cache_key = _provider_cache_key("tenant-1", 0) + legacy_cache_key = _provider_cache_key("tenant-1") + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.return_value = None + redis_client.mget.side_effect = RedisError("redis unavailable") + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + + from core.plugin.plugin_service import PluginService + + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key]) + assert [provider.provider for provider in result] == ["langgenius/openai/openai"] + def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None: """Redis write failures are non-fatal after fresh provider data has been fetched.""" with patch(f"{MODULE}.redis_client") as redis_client: redis_client.get.return_value = None + redis_client.mget.return_value = [None, None] redis_client.setex.side_effect = RedisError("redis unavailable") client = Mock() client.fetch_model_providers.return_value = [_build_plugin_model_provider()] @@ -191,6 +242,7 @@ class TestPluginModelProviderCache: patch(f"{MODULE}.PluginModelClient") as client_cls, ): redis_client.get.return_value = None + redis_client.mget.return_value = [None, None] client = client_cls.return_value client.fetch_model_providers.return_value = [_build_plugin_model_provider()] @@ -202,23 +254,98 @@ class TestPluginModelProviderCache: client.fetch_model_providers.assert_called_once_with("tenant-1") assert [provider.provider for provider in result] == ["langgenius/openai/openai"] - def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None: - with patch(f"{MODULE}.redis_client") as redis_client: + def test_fetch_plugin_model_providers_reuses_process_local_cache(self) -> None: + generation_key = _provider_generation_key("tenant-1") + with ( + patch(f"{MODULE}.redis_client") as redis_client, + patch(f"{MODULE}.PluginModelClient") as client_cls, + ): + redis_client.get.side_effect = [None, None, None] + redis_client.mget.return_value = [None, None] + client = client_cls.return_value + client.fetch_model_providers.return_value = [_build_plugin_model_provider()] + from core.plugin.plugin_service import PluginService - PluginService.invalidate_plugin_model_providers_cache("tenant-1") + first_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1") + redis_client.get.reset_mock() + redis_client.mget.reset_mock() + redis_client.setex.reset_mock() + client.fetch_model_providers.reset_mock() - redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + second_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1") - def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None: + redis_client.get.assert_called_once_with(generation_key) + redis_client.mget.assert_not_called() + redis_client.setex.assert_not_called() + client.fetch_model_providers.assert_not_called() + assert [provider.provider for provider in second_result] == ["langgenius/openai/openai"] + assert second_result[0] == first_result[0] + assert second_result[0] is not first_result[0] + + def test_invalidate_plugin_model_providers_cache_uses_redis_pipeline(self) -> None: with patch(f"{MODULE}.redis_client") as redis_client: - redis_client.delete.side_effect = RedisError("redis unavailable") + pipe = redis_client.pipeline.return_value from core.plugin.plugin_service import PluginService PluginService.invalidate_plugin_model_providers_cache("tenant-1") - redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1") + redis_client.pipeline.assert_called_once_with(transaction=False) + pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1")) + pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1")) + pipe.execute.assert_called_once_with() + + def test_invalidate_plugin_model_providers_cache_ignores_redis_pipeline_failure(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + pipe = redis_client.pipeline.return_value + pipe.execute.side_effect = RedisError("redis unavailable") + + from core.plugin.plugin_service import PluginService + + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + redis_client.pipeline.assert_called_once_with(transaction=False) + pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1")) + pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1")) + pipe.execute.assert_called_once_with() + + def test_invalidate_plugin_model_providers_cache_clears_process_local_cache(self) -> None: + with patch(f"{MODULE}.redis_client") as redis_client: + pipe = redis_client.pipeline.return_value + + from core.plugin.plugin_service import PluginService + + PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()]) + PluginService.invalidate_plugin_model_providers_cache("tenant-1") + + assert PluginService._plugin_model_providers_memory_cache == {} + redis_client.pipeline.assert_called_once_with(transaction=False) + pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1")) + pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1")) + pipe.execute.assert_called_once_with() + + def test_fetch_plugin_model_providers_ignores_stale_process_local_cache_after_generation_bump(self) -> None: + generation_key = _provider_generation_key("tenant-1") + new_cache_key = _provider_cache_key("tenant-1", 1) + with patch(f"{MODULE}.redis_client") as redis_client: + redis_client.get.side_effect = [b"1", b"1"] + redis_client.mget.return_value = [None] + client = Mock() + client.fetch_model_providers.return_value = [_build_plugin_model_provider(provider="anthropic")] + + from core.plugin.plugin_service import PluginService + + PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()]) + result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client) + + client.fetch_model_providers.assert_called_once_with("tenant-1") + redis_client.get.assert_has_calls([call(generation_key), call(generation_key)]) + redis_client.mget.assert_called_once_with([new_cache_key]) + redis_client.setex.assert_called_once() + assert redis_client.setex.call_args.args[0] == new_cache_key + assert PluginService._plugin_model_providers_memory_cache["tenant-1"][0] == 1 + assert [provider.provider for provider in result] == ["langgenius/anthropic/anthropic"] class TestPluginListEndpointCounts: