mirror of
https://github.com/langgenius/dify.git
synced 2026-06-12 19:53:38 +08:00
perf(api): reduce workflow startup latency for chatflow (#36773)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
632df88228
commit
c4a8d79be9
@ -4,6 +4,7 @@ from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelWithProviderEntity
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
@ -19,6 +20,10 @@ class DifyCredentialsProvider:
|
||||
|
||||
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||
The matching validated provider model is cached alongside the credentials so
|
||||
follow-up workflow startup checks can reuse the earlier provider lookup
|
||||
instead of resolving the same model metadata a second time.
|
||||
|
||||
Because of that cache, a single instance can return stale credentials after
|
||||
the tenant or provider configuration changes (e.g. API key rotation).
|
||||
|
||||
@ -30,6 +35,7 @@ class DifyCredentialsProvider:
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||
provider_model_cache: dict[tuple[str, str], ModelWithProviderEntity]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -45,12 +51,21 @@ class DifyCredentialsProvider:
|
||||
)
|
||||
self.provider_manager = provider_manager
|
||||
self.credentials_cache = {}
|
||||
self.provider_model_cache = {}
|
||||
|
||||
def get_cached_provider_model(self, provider_name: str, model_name: str) -> ModelWithProviderEntity | None:
    """Return a copy of the provider model cached for ``(provider_name, model_name)``.

    Returns ``None`` when nothing is cached for that pair. The cached entry is
    handed out as a deep copy, so the caller never receives the cached object itself.
    """
    cache_key = (provider_name, model_name)
    cached_entry = self.provider_model_cache.get(cache_key)
    return None if cached_entry is None else cached_entry.model_copy(deep=True)
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
if (provider_name, model_name) in self.credentials_cache:
|
||||
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider_name} does not exist.")
|
||||
@ -65,6 +80,7 @@ class DifyCredentialsProvider:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||
self.provider_model_cache[(provider_name, model_name)] = provider_model.model_copy(deep=True)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -142,13 +158,21 @@ def fetch_model_config(
|
||||
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
provider_model = None
|
||||
if isinstance(credentials_provider, DifyCredentialsProvider):
|
||||
provider_model = credentials_provider.get_cached_provider_model(
|
||||
provider_name=node_data_model.provider,
|
||||
model_name=node_data_model.name,
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
|
||||
provider_model.raise_for_status()
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||
if model_schema is None:
|
||||
|
||||
@ -69,6 +69,11 @@ class ProviderConfiguration(BaseModel):
|
||||
nested schema and model lookups reuse the caller scope that was already
|
||||
resolved by the composition layer.
|
||||
|
||||
The ``provider`` field already contains the resolved provider schema that
|
||||
was used to build this configuration. Reuse that schema for nested model
|
||||
lookups instead of refetching the full provider catalog from the runtime on
|
||||
every request-scoped lookup.
|
||||
|
||||
TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
|
||||
"""
|
||||
|
||||
@ -83,15 +88,19 @@ class ProviderConfiguration(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
_bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)
|
||||
_cached_provider_schema: ProviderEntity | None = PrivateAttr(default=None)
|
||||
_original_provider_configurate_methods: tuple[ConfigurateMethod, ...] = PrivateAttr(default_factory=tuple)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
self._original_provider_configurate_methods = tuple(self.provider.configurate_methods)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in self.provider.configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if list(self._original_provider_configurate_methods) == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if (
|
||||
any(
|
||||
len(quota_configuration.restrict_models) > 0
|
||||
@ -105,6 +114,29 @@ class ProviderConfiguration(BaseModel):
|
||||
def bind_model_runtime(self, model_runtime: ModelRuntime) -> None:
    """Bind the caller's already-composed runtime to this configuration.

    The configuration's own ``provider`` entity is recorded as the cached
    provider schema at the same time.
    """
    self._cached_provider_schema = self.provider
    self._bound_model_runtime = model_runtime
|
||||
|
||||
def _get_original_provider_configurate_methods(self) -> list[ConfigurateMethod]:
    """Return a new list built from the configurate methods stored on this instance."""
    return [*self._original_provider_configurate_methods]
|
||||
|
||||
def _get_provider_schema(self, *, model_provider_factory: ModelProviderFactory | None = None) -> ProviderEntity:
    """Return the provider schema, resolving and memoizing it on first use.

    A schema that is already memoized is returned as is. Otherwise
    ``self.provider`` is used when it lists any models; only when it lists none
    is the schema requested from ``model_provider_factory`` (or, when that is
    not given, from ``self.get_model_provider_factory()``).
    """
    resolved_schema = self._cached_provider_schema
    if resolved_schema is not None:
        return resolved_schema

    if self.provider.models:
        resolved_schema = self.provider
    else:
        factory = model_provider_factory or self.get_model_provider_factory()
        resolved_schema = factory.get_provider_schema(provider=self.provider.provider)

    self._cached_provider_schema = resolved_schema
    return resolved_schema
|
||||
|
||||
def _get_model_runtime(self) -> ModelRuntime:
    """Return the bound runtime when one is attached; otherwise assemble one for this tenant."""
    bound_runtime = self._bound_model_runtime
    if bound_runtime is None:
        # Nothing was bound, so compose a plugin model assembly for the tenant and use its runtime.
        return create_plugin_model_assembly(tenant_id=self.tenant_id).model_runtime
    return bound_runtime
|
||||
|
||||
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
|
||||
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
|
||||
@ -153,7 +185,6 @@ class ProviderConfiguration(BaseModel):
|
||||
and restrict_model.base_model_name
|
||||
):
|
||||
copy_credentials["base_model_name"] = restrict_model.base_model_name
|
||||
|
||||
return copy_credentials
|
||||
else:
|
||||
credentials = None
|
||||
@ -189,7 +220,6 @@ class ProviderConfiguration(BaseModel):
|
||||
provider=self.provider.provider,
|
||||
credential_type=PluginCredentialType.MODEL,
|
||||
)
|
||||
|
||||
return credentials
|
||||
|
||||
def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
|
||||
@ -1399,8 +1429,13 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
|
||||
if self._bound_model_runtime is not None:
|
||||
model_runtime = self._bound_model_runtime
|
||||
else:
|
||||
model_runtime, _ = self._get_runtime_and_provider_factory()
|
||||
|
||||
provider_schema = self._cached_provider_schema or self.provider
|
||||
|
||||
return create_model_type_instance(
|
||||
runtime=model_runtime,
|
||||
provider_schema=provider_schema,
|
||||
@ -1410,12 +1445,13 @@ class ProviderConfiguration(BaseModel):
|
||||
def get_model_schema(
    self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
    """Get model schema with the request-bound runtime and canonical provider id.

    :param model_type: model type
    :param model: model name
    :param credentials: model credentials; ``None`` is forwarded as an empty dict
    :return: whatever the runtime's ``get_model_schema`` returns for this provider
    """
    # Resolve the runtime through _get_model_runtime() so a bound runtime is reused.
    model_runtime = self._get_model_runtime()
    return model_runtime.get_model_schema(
        provider=self.provider.provider,
        model_type=model_type,
        model=model,
        credentials=credentials or {},
    )
|
||||
|
||||
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
|
||||
@ -1515,8 +1551,7 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
|
||||
provider_schema = self._get_provider_schema()
|
||||
|
||||
model_types: list[ModelType] = []
|
||||
if model_type:
|
||||
@ -1531,7 +1566,10 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if self.using_provider_type == ProviderType.SYSTEM:
|
||||
provider_models = self._get_system_provider_models(
|
||||
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
|
||||
model_types=model_types,
|
||||
provider_schema=provider_schema,
|
||||
model_setting_map=model_setting_map,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
provider_models = self._get_custom_provider_models(
|
||||
@ -1573,6 +1611,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_types: Sequence[ModelType],
|
||||
provider_schema: ProviderEntity,
|
||||
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
||||
model: str | None = None,
|
||||
) -> list[ModelWithProviderEntity]:
|
||||
"""
|
||||
Get system provider models.
|
||||
@ -1587,6 +1626,8 @@ class ProviderConfiguration(BaseModel):
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
if model and m.model != model:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE
|
||||
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
|
||||
@ -1608,13 +1649,9 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider.provider not in original_provider_configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider] = []
|
||||
for configurate_method in provider_schema.configurate_methods:
|
||||
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
|
||||
|
||||
original_configurate_methods = self._get_original_provider_configurate_methods()
|
||||
should_use_custom_model = False
|
||||
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
should_use_custom_model = True
|
||||
|
||||
for quota_configuration in self.system_configuration.quota_configurations:
|
||||
@ -1626,11 +1663,12 @@ class ProviderConfiguration(BaseModel):
|
||||
break
|
||||
|
||||
if should_use_custom_model:
|
||||
if original_provider_configurate_methods[self.provider.provider] == [
|
||||
ConfigurateMethod.CUSTOMIZABLE_MODEL
|
||||
]:
|
||||
if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
|
||||
# only customizable model
|
||||
for restrict_model in restrict_models:
|
||||
if model and restrict_model.model != model:
|
||||
continue
|
||||
|
||||
copy_credentials = (
|
||||
self.system_configuration.credentials.copy()
|
||||
if self.system_configuration.credentials
|
||||
@ -1680,11 +1718,11 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
# if llm name not in restricted llm list, remove it
|
||||
restrict_model_names = [rm.model for rm in restrict_models]
|
||||
for model in provider_models:
|
||||
if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
|
||||
model.status = ModelStatus.NO_PERMISSION
|
||||
for provider_model in provider_models:
|
||||
if provider_model.model_type == ModelType.LLM and provider_model.model not in restrict_model_names:
|
||||
provider_model.status = ModelStatus.NO_PERMISSION
|
||||
elif not quota_configuration.is_valid:
|
||||
model.status = ModelStatus.QUOTA_EXCEEDED
|
||||
provider_model.status = ModelStatus.QUOTA_EXCEEDED
|
||||
|
||||
return provider_models
|
||||
|
||||
@ -1709,6 +1747,13 @@ class ProviderConfiguration(BaseModel):
|
||||
if self.custom_configuration.provider:
|
||||
credentials = self.custom_configuration.provider.credentials
|
||||
|
||||
requested_predefined_model = False
|
||||
if model:
|
||||
requested_predefined_model = any(
|
||||
predefined_model.model_type in model_types and predefined_model.model == model
|
||||
for predefined_model in provider_schema.models
|
||||
)
|
||||
|
||||
for model_type in model_types:
|
||||
if model_type not in self.provider.supported_model_types:
|
||||
continue
|
||||
@ -1716,6 +1761,8 @@ class ProviderConfiguration(BaseModel):
|
||||
for m in provider_schema.models:
|
||||
if m.model_type != model_type:
|
||||
continue
|
||||
if requested_predefined_model and model and m.model != model:
|
||||
continue
|
||||
|
||||
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
|
||||
load_balancing_enabled = False
|
||||
|
||||
@ -13,8 +13,10 @@ metadata.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_type
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
from redis import RedisError
|
||||
@ -64,6 +66,8 @@ _provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list
|
||||
|
||||
|
||||
class PluginService:
|
||||
_plugin_model_providers_memory_cache: ClassVar[dict[str, tuple[int, float, tuple[ProviderEntity, ...]]]] = {}
|
||||
|
||||
class LatestPluginCache(BaseModel):
|
||||
plugin_id: str
|
||||
version: str
|
||||
@ -75,14 +79,22 @@ class PluginService:
|
||||
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
||||
REDIS_TTL = 60 * 5 # 5 minutes
|
||||
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
|
||||
PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX = "plugin_model_providers_generation:tenant_id:"
|
||||
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
|
||||
# Mirror the detail-panel endpoint query size so list reconciliation and
|
||||
# the visible endpoint drawer exercise the same daemon pagination path.
|
||||
ENDPOINT_RECONCILIATION_PAGE_SIZE = 100
|
||||
|
||||
@classmethod
|
||||
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
|
||||
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
|
||||
def _get_plugin_model_providers_cache_key(cls, tenant_id: str, generation: int | None = None) -> str:
|
||||
if generation is None:
|
||||
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
|
||||
|
||||
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}:generation:{generation}"
|
||||
|
||||
@classmethod
def _get_plugin_model_providers_generation_cache_key(cls, tenant_id: str) -> str:
    """Build the Redis key that holds the tenant's provider-cache generation value."""
    key_prefix = cls.PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX
    return f"{key_prefix}{tenant_id}"
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
|
||||
@ -115,29 +127,133 @@ class PluginService:
|
||||
return declaration
|
||||
|
||||
@classmethod
|
||||
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
|
||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
||||
def _copy_provider_entities(cls, providers: Sequence[ProviderEntity]) -> tuple[ProviderEntity, ...]:
    """Return a tuple holding a deep copy of every entity in ``providers``, in order."""
    copies = [entity.model_copy(deep=True) for entity in providers]
    return tuple(copies)
|
||||
|
||||
@classmethod
def _load_plugin_model_providers_generation(cls, tenant_id: str) -> int | None:
    """Read the tenant's plugin model provider cache generation from Redis.

    :param tenant_id: tenant whose generation marker is read
    :return: ``0`` when no marker is stored, the stored marker as an ``int``
        otherwise, or ``None`` when Redis could not be read or the stored value
        is not an integer
    """
    cache_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id)
    try:
        # Single read of the generation marker; no other key is fetched here.
        cached_generation = redis_client.get(cache_key)
    except (RedisError, RuntimeError):
        logger.warning("Failed to read plugin model provider generation for tenant %s.", tenant_id, exc_info=True)
        return None

    if cached_generation is None:
        return 0

    try:
        return int(cached_generation)
    except (TypeError, ValueError):
        logger.warning(
            "Invalid plugin model provider generation for tenant %s; deleting cache marker.",
            tenant_id,
            exc_info=True,
        )
        # Best-effort cleanup: a failed delete is logged, never raised.
        try:
            redis_client.delete(cache_key)
        except (RedisError, RuntimeError):
            logger.warning(
                "Failed to delete invalid plugin model provider generation for tenant %s.",
                tenant_id,
                exc_info=True,
            )
        return None
|
||||
|
||||
@classmethod
def _load_in_memory_plugin_model_providers(
    cls, memory_cache_key: str, generation: int
) -> tuple[ProviderEntity, ...] | None:
    """Look up the class-level in-memory provider cache.

    Returns copies of the cached entities when the entry stored under
    ``memory_cache_key`` carries the same ``generation`` and its monotonic-clock
    deadline has not been reached. An entry failing either check is removed and
    ``None`` is returned; ``None`` is also returned when no entry exists.
    """
    memory_cache = cls._plugin_model_providers_memory_cache
    entry = memory_cache.get(memory_cache_key)
    if entry is None:
        return None

    entry_generation, deadline, cached_entities = entry
    is_stale = entry_generation != generation or time.monotonic() >= deadline
    if not is_stale:
        return cls._copy_provider_entities(cached_entities)

    memory_cache.pop(memory_cache_key, None)
    return None
|
||||
|
||||
@classmethod
def _store_in_memory_plugin_model_providers(
    cls, memory_cache_key: str, generation: int, providers: Sequence[ProviderEntity]
) -> None:
    """Write ``providers`` into the class-level in-memory cache.

    The stored entry holds ``generation``, a monotonic-clock deadline
    ``dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL`` ahead of now, and copies of
    the entities. With a non-positive TTL nothing is stored and any existing
    entry under ``memory_cache_key`` is removed instead.
    """
    ttl = dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL
    memory_cache = cls._plugin_model_providers_memory_cache
    if ttl > 0:
        deadline = time.monotonic() + ttl
        memory_cache[memory_cache_key] = (generation, deadline, cls._copy_provider_entities(providers))
        return

    memory_cache.pop(memory_cache_key, None)
|
||||
|
||||
@classmethod
def _load_cached_plugin_model_providers(
    cls, tenant_id: str, *, client: PluginModelClient | None = None
) -> tuple[ProviderEntity, ...] | None:
    """Load the tenant's provider list from the in-memory cache, then from Redis.

    :param tenant_id: tenant whose cached provider list is requested
    :param client: accepted for interface compatibility; not used by this lookup
    :return: the cached providers, or ``None`` when the generation marker is
        unavailable, Redis cannot be read, or no valid cached payload exists
    """
    generation = cls._load_plugin_model_providers_generation(tenant_id)
    if generation is None:
        # Without a generation there is no key to look up in either cache layer.
        return None

    in_memory_cached_providers = cls._load_in_memory_plugin_model_providers(tenant_id, generation)
    if in_memory_cached_providers is not None:
        return in_memory_cached_providers

    cache_keys = [cls._get_plugin_model_providers_cache_key(tenant_id, generation)]
    if generation == 0:
        # Generation 0 also falls back to the key without a generation suffix.
        cache_keys.append(cls._get_plugin_model_providers_cache_key(tenant_id))

    try:
        cached_provider_entries = redis_client.mget(cache_keys)
    except (RedisError, RuntimeError):
        logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
        return None

    if len(cached_provider_entries) != len(cache_keys):
        logger.warning(
            "Unexpected cached plugin model providers response size for tenant %s.",
            tenant_id,
        )
        return None

    for cache_key, cached_providers in zip(cache_keys, cached_provider_entries):
        if not cached_providers:
            continue

        try:
            providers = tuple(_provider_entities_adapter.validate_json(cached_providers))
            cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers)
            return providers
        except (TypeError, ValueError, ValidationError):
            logger.warning(
                "Invalid cached plugin model providers for tenant %s; deleting cache key %s.",
                tenant_id,
                cache_key,
                exc_info=True,
            )
            # Best-effort removal of the unparsable payload; continue with the next key.
            try:
                redis_client.delete(cache_key)
            except (RedisError, RuntimeError):
                logger.warning(
                    "Failed to delete invalid cached plugin model providers for tenant %s.",
                    tenant_id,
                    exc_info=True,
                )

    return None
|
||||
|
||||
@classmethod
|
||||
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
|
||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
||||
def _store_cached_plugin_model_providers(
|
||||
cls, tenant_id: str, generation: int, providers: Sequence[ProviderEntity]
|
||||
) -> None:
|
||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id, generation)
|
||||
try:
|
||||
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
|
||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
|
||||
@ -146,9 +262,15 @@ class PluginService:
|
||||
|
||||
@classmethod
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
    """Invalidate tenant-scoped provider metadata across Redis and worker-local mirrors.

    Drops this process's in-memory entry for the tenant, then, in one
    non-transactional pipeline, deletes the cache key without a generation
    suffix and increments the tenant's generation marker. Redis failures are
    logged and not raised.
    """
    cls._plugin_model_providers_memory_cache.pop(tenant_id, None)
    cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
    generation_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id)
    try:
        # Both commands are queued on the pipeline and sent by execute().
        pipe = redis_client.pipeline(transaction=False)
        pipe.delete(cache_key)
        pipe.incr(generation_key)
        pipe.execute()
    except (RedisError, RuntimeError):
        logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
|
||||
|
||||
@ -163,7 +285,7 @@ class PluginService:
|
||||
are intentionally owned by this service so tenant isolation and cache
|
||||
expiry are handled in one place.
|
||||
"""
|
||||
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
|
||||
cached_providers = cls._load_cached_plugin_model_providers(tenant_id, client=client)
|
||||
if cached_providers is not None:
|
||||
return cached_providers
|
||||
|
||||
@ -171,7 +293,12 @@ class PluginService:
|
||||
providers = tuple(
|
||||
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
|
||||
)
|
||||
cls._store_cached_plugin_model_providers(tenant_id, providers)
|
||||
if not providers:
|
||||
return providers
|
||||
generation = cls._load_plugin_model_providers_generation(tenant_id)
|
||||
if generation is not None:
|
||||
cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers)
|
||||
cls._store_cached_plugin_model_providers(tenant_id, generation, providers)
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -433,10 +433,9 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
mock_assembly = Mock()
|
||||
mock_assembly.model_runtime = Mock()
|
||||
mock_assembly.model_runtime.get_model_schema.return_value = mock_schema
|
||||
mock_assembly.model_provider_factory = mock_factory
|
||||
|
||||
with (
|
||||
@ -455,13 +454,12 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_assembly_builder.call_count == 2
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=mock_assembly.model_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
mock_factory.get_model_schema.assert_called_once_with(
|
||||
mock_assembly.model_runtime.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
@ -472,18 +470,13 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
bound_runtime = Mock()
|
||||
bound_runtime.get_model_schema.return_value = _build_ai_model("gpt-4o")
|
||||
configuration.bind_model_runtime(bound_runtime)
|
||||
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory = Mock()
|
||||
mock_factory.get_provider_schema.return_value = configuration.provider
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
|
||||
) as mock_factory_cls,
|
||||
patch("core.entities.provider_configuration.ModelProviderFactory") as mock_factory_cls,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_model_type_instance",
|
||||
@ -494,16 +487,20 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_cls.call_count == 2
|
||||
mock_factory_cls.assert_called_with(runtime=bound_runtime)
|
||||
assert model_schema == bound_runtime.get_model_schema.return_value
|
||||
mock_factory_cls.assert_not_called()
|
||||
mock_assembly_builder.assert_not_called()
|
||||
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
|
||||
mock_model_builder.assert_called_once_with(
|
||||
runtime=bound_runtime,
|
||||
provider_schema=configuration.provider,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
bound_runtime.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-4o",
|
||||
credentials={"api_key": "x"},
|
||||
)
|
||||
|
||||
|
||||
def test_get_provider_model_returns_none_when_model_not_found() -> None:
|
||||
@ -544,6 +541,99 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
|
||||
assert [model.model for model in active_models] == ["b-model"]
|
||||
|
||||
|
||||
def test_get_provider_models_system_filters_requested_model() -> None:
    """Passing ``model`` to ``get_provider_models`` returns only the matching entry."""
    configuration = _build_provider_configuration()
    schema_model_names = ("a-model", "target-model", "b-model")
    provider_schema = ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        models=[_build_ai_model(name) for name in schema_model_names],
    )
    mock_factory = Mock()
    mock_factory.get_provider_schema.return_value = provider_schema
    assembly = SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory)

    with patch("core.entities.provider_configuration.create_plugin_model_assembly", return_value=assembly):
        models = configuration.get_provider_models(
            model_type=ModelType.LLM,
            only_active=False,
            model="target-model",
        )

    returned_names = [provider_model.model for provider_model in models]
    assert returned_names == ["target-model"]
|
||||
|
||||
|
||||
def test_get_provider_models_system_customizable_filters_requested_restricted_model() -> None:
    """With a customizable-only provider, only the requested restricted model is resolved."""
    provider = ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL],
    )
    restricted_models = [
        RestrictModel(model=name, base_model_name="base-model", model_type=ModelType.LLM)
        for name in ("target-model", "other-model")
    ]
    trial_quota = QuotaConfiguration(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=1_000,
        quota_used=0,
        is_valid=True,
        restrict_models=restricted_models,
    )
    system_configuration = SystemConfiguration(
        enabled=True,
        credentials={"api_key": "test-key"},
        current_quota_type=ProviderQuotaType.TRIAL,
        quota_configurations=[trial_quota],
    )
    provider_schema = ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        models=[],
    )
    mock_factory = Mock()
    mock_factory.get_provider_schema.return_value = provider_schema

    # The configuration is built while the module-level mapping is patched to an empty dict.
    with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
        configuration = ProviderConfiguration(
            tenant_id="tenant-1",
            provider=provider,
            preferred_provider_type=ProviderType.SYSTEM,
            using_provider_type=ProviderType.SYSTEM,
            system_configuration=system_configuration,
            custom_configuration=CustomConfiguration(provider=None, models=[]),
            model_settings=[],
        )

    assembly = SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory)
    with (
        patch("core.entities.provider_configuration.create_plugin_model_assembly", return_value=assembly),
        patch.object(
            ProviderConfiguration,
            "get_model_schema",
            side_effect=lambda *args, **kwargs: _build_ai_model(kwargs["model"]),
        ) as mock_get_model_schema,
    ):
        models = configuration.get_provider_models(
            model_type=ModelType.LLM,
            only_active=False,
            model="target-model",
        )

    returned_names = [provider_model.model for provider_model in models]
    assert returned_names == ["target-model"]
    mock_get_model_schema.assert_called_once()
    assert mock_get_model_schema.call_args.kwargs["model"] == "target-model"
|
||||
|
||||
|
||||
def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
configuration.using_provider_type = ProviderType.CUSTOM
|
||||
@ -611,6 +701,48 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
|
||||
assert invalid_lb_map["custom-model"] is True
|
||||
|
||||
|
||||
def test_get_custom_provider_models_filters_requested_base_model() -> None:
    """Passing ``model`` to ``_get_custom_provider_models`` narrows the result to that model.

    The provider schema declares both ``base-model`` and ``target-model``; only
    ``target-model`` is requested, so only it may be returned.
    """
    configuration = _build_provider_configuration()
    configuration.using_provider_type = ProviderType.CUSTOM
    configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"})
    provider_schema = ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        models=[_build_ai_model("base-model"), _build_ai_model("target-model")],
    )

    models = configuration._get_custom_provider_models(
        model_types=[ModelType.LLM],
        provider_schema=provider_schema,
        model_setting_map={},
        model="target-model",
    )

    assert [model.model for model in models] == ["target-model"]
|
||||
|
||||
|
||||
def test_get_provider_models_reuses_cached_provider_schema() -> None:
    """Repeated ``get_provider_models`` calls must not build a plugin model assembly.

    The schema is placed directly on ``configuration.provider``; two lookups for
    different models are then expected to leave ``create_plugin_model_assembly``
    uncalled.
    """
    configuration = _build_provider_configuration()
    provider_schema = ProviderEntity(
        provider="openai",
        label=I18nObject(en_US="OpenAI"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
        models=[_build_ai_model("a-model"), _build_ai_model("b-model")],
    )
    configuration.provider = provider_schema

    with patch(
        "core.entities.provider_configuration.create_plugin_model_assembly",
    ) as mock_assembly_builder:
        configuration.get_provider_models(model_type=ModelType.LLM, model="a-model")
        configuration.get_provider_models(model_type=ModelType.LLM, model="b-model")

    mock_assembly_builder.assert_not_called()
|
||||
|
||||
|
||||
def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None:
|
||||
provider = ProviderEntity(
|
||||
provider="openai",
|
||||
@ -1402,25 +1534,22 @@ def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() -
|
||||
return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING)
|
||||
return _build_ai_model("target")
|
||||
|
||||
with patch(
|
||||
"core.entities.provider_configuration.original_provider_configurate_methods",
|
||||
{"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]},
|
||||
):
|
||||
with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema):
|
||||
system_models = configuration._get_system_provider_models(
|
||||
model_types=[ModelType.LLM],
|
||||
provider_schema=provider_schema,
|
||||
model_setting_map={
|
||||
ModelType.LLM: {
|
||||
"target": ModelSettings(
|
||||
model="target",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=False,
|
||||
load_balancing_configs=[],
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
configuration._original_provider_configurate_methods = (ConfigurateMethod.CUSTOMIZABLE_MODEL,)
|
||||
with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema):
|
||||
system_models = configuration._get_system_provider_models(
|
||||
model_types=[ModelType.LLM],
|
||||
provider_schema=provider_schema,
|
||||
model_setting_map={
|
||||
ModelType.LLM: {
|
||||
"target": ModelSettings(
|
||||
model="target",
|
||||
model_type=ModelType.LLM,
|
||||
enabled=False,
|
||||
load_balancing_configs=[],
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models)
|
||||
|
||||
configuration.using_provider_type = ProviderType.CUSTOM
|
||||
|
||||
@ -28,6 +28,9 @@ class _FakeRedis:
|
||||
def get(self, key: str) -> str | None:
|
||||
return self._values.get(key)
|
||||
|
||||
def mget(self, keys: list[str]) -> list[str | None]:
|
||||
return [self.get(key) for key in keys]
|
||||
|
||||
def setex(self, key: str, ttl: int, value: str) -> None:
|
||||
self._values[key] = value
|
||||
self.setex_calls.append((key, ttl, value))
|
||||
@ -36,6 +39,13 @@ class _FakeRedis:
|
||||
self._values.pop(key, None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_plugin_model_provider_memory_cache() -> None:
    """Empty ``PluginService._plugin_model_providers_memory_cache`` around each test.

    The cache is a class-level attribute, so entries written by one test would
    otherwise be visible to the next. It is cleared both before the test body
    runs and again after it finishes.
    """
    # NOTE(review): this is a generator fixture; ``Iterator[None]`` would be the
    # precise return annotation, but the file's import block is not visible here.
    PluginService._plugin_model_providers_memory_cache.clear()
    yield
    PluginService._plugin_model_providers_memory_cache.clear()
|
||||
|
||||
|
||||
def _build_model_schema() -> AIModelEntity:
|
||||
return AIModelEntity(
|
||||
model="gpt-4o-mini",
|
||||
@ -329,6 +339,7 @@ class TestPluginModelRuntime:
|
||||
"redis_client",
|
||||
SimpleNamespace(
|
||||
get=Mock(return_value=None),
|
||||
mget=Mock(return_value=[None, None]),
|
||||
delete=Mock(),
|
||||
setex=Mock(),
|
||||
),
|
||||
|
||||
@ -345,6 +345,62 @@ def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_confi
|
||||
provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
def test_fetch_model_config_reuses_validated_provider_model_from_dify_credentials_provider(
    model_config: ModelConfigWithCredentialsEntity,
):
    """``fetch_model_config`` performs a single provider-model lookup and status check.

    A real ``DifyCredentialsProvider`` is wired to a mocked provider manager, so
    the number of ``get_provider_model`` / ``raise_for_status`` calls observed on
    the mocks reflects how often the provider model is resolved and validated.
    """
    mock_provider_manager = mock.MagicMock()
    mock_configurations = mock.MagicMock()
    mock_provider_configuration = mock.MagicMock()
    mock_provider_model = mock.MagicMock()
    mock_model_factory = mock.MagicMock(spec=DifyModelFactory)

    mock_configurations.get.return_value = mock_provider_configuration
    mock_provider_configuration.get_provider_model.return_value = mock_provider_model
    mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
    mock_provider_manager.get_configurations.return_value = mock_configurations

    run_context = DifyRunContext(
        tenant_id="tenant",
        app_id="app",
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
    )
    credentials_provider = DifyCredentialsProvider(
        run_context=run_context,
        provider_manager=mock_provider_manager,
    )

    model_instance = mock.MagicMock(
        model_type_instance=model_config.provider_model_bundle.model_type_instance,
        provider_model_bundle=model_config.provider_model_bundle,
    )
    mock_model_factory.init_model_instance.return_value = model_instance

    with mock.patch.object(
        model_instance.model_type_instance.__class__,
        "get_model_schema",
        return_value=model_config.model_schema,
        autospec=True,
    ):
        fetch_model_config(
            node_data_model=ModelConfig(
                provider="openai",
                name="gpt-3.5-turbo",
                mode="chat",
                completion_params={},
            ),
            credentials_provider=credentials_provider,
            model_factory=mock_model_factory,
        )

    # Exactly one lookup and one validation: a second resolution of the same
    # provider model would make either assertion fail.
    mock_provider_configuration.get_provider_model.assert_called_once_with(
        model_type=ModelType.LLM,
        model="gpt-3.5-turbo",
    )
    mock_provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
def test_dify_model_access_adapters_call_managers():
|
||||
mock_provider_manager = mock.MagicMock()
|
||||
mock_model_manager = mock.MagicMock()
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, call, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
from redis import RedisError
|
||||
|
||||
@ -13,6 +14,15 @@ from graphon.model_runtime.entities.provider_entities import ConfigurateMethod,
|
||||
MODULE = "core.plugin.plugin_service"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_plugin_model_provider_memory_cache() -> None:
    """Empty ``PluginService._plugin_model_providers_memory_cache`` around each test.

    ``PluginService`` is imported inside the fixture, matching the function-local
    imports used by the tests in this module. The cache is cleared before the
    test body runs and again after it finishes.
    """
    # NOTE(review): this is a generator fixture; ``Iterator[None]`` would be the
    # precise return annotation, but that name is not imported in the visible
    # import block.
    from core.plugin.plugin_service import PluginService

    PluginService._plugin_model_providers_memory_cache.clear()
    yield
    PluginService._plugin_model_providers_memory_cache.clear()
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.execute = Mock()
|
||||
@ -68,6 +78,17 @@ def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskSta
|
||||
)
|
||||
|
||||
|
||||
def _provider_cache_key(tenant_id: str, generation: int | None = None) -> str:
|
||||
if generation is None:
|
||||
return f"plugin_model_providers:tenant_id:{tenant_id}"
|
||||
|
||||
return f"plugin_model_providers:tenant_id:{tenant_id}:generation:{generation}"
|
||||
|
||||
|
||||
def _provider_generation_key(tenant_id: str) -> str:
|
||||
return f"plugin_model_providers_generation:tenant_id:{tenant_id}"
|
||||
|
||||
|
||||
class TestFetchLatestPluginVersion:
|
||||
def test_skips_marketplace_fetch_when_disabled(self) -> None:
|
||||
"""Cache misses stay None; marketplace is never called when disabled."""
|
||||
@ -120,9 +141,13 @@ class TestPluginModelProviderCache:
|
||||
"""A valid tenant cache entry is reused across runtime calls without plugin daemon access."""
|
||||
cached_provider = _build_provider_entity()
|
||||
cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8")
|
||||
generation_key = _provider_generation_key("tenant-1")
|
||||
cache_key = _provider_cache_key("tenant-1", 0)
|
||||
legacy_cache_key = _provider_cache_key("tenant-1")
|
||||
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.get.return_value = cached_payload
|
||||
redis_client.get.return_value = None
|
||||
redis_client.mget.return_value = [cached_payload, None]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
@ -132,14 +157,20 @@ class TestPluginModelProviderCache:
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
client.fetch_model_providers.assert_not_called()
|
||||
redis_client.setex.assert_not_called()
|
||||
redis_client.get.assert_called_once_with(generation_key)
|
||||
redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
|
||||
|
||||
def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None:
|
||||
"""Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon."""
|
||||
"""Invalid generation-scoped cache payloads are removed before falling back to the daemon."""
|
||||
generation_key = _provider_generation_key("tenant-1")
|
||||
cache_key = _provider_cache_key("tenant-1", 0)
|
||||
legacy_cache_key = _provider_cache_key("tenant-1")
|
||||
with (
|
||||
patch(f"{MODULE}.redis_client") as redis_client,
|
||||
patch(f"{MODULE}.dify_config") as mock_config,
|
||||
):
|
||||
redis_client.get.return_value = "not-json"
|
||||
redis_client.get.side_effect = [None, None]
|
||||
redis_client.mget.return_value = ["not-json", None]
|
||||
mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400
|
||||
client = Mock()
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
@ -148,12 +179,13 @@ class TestPluginModelProviderCache:
|
||||
|
||||
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
|
||||
|
||||
cache_key = "plugin_model_providers:tenant_id:tenant-1"
|
||||
redis_client.delete.assert_called_once_with(cache_key)
|
||||
redis_client.setex.assert_called_once()
|
||||
assert redis_client.setex.call_args.args[0] == cache_key
|
||||
assert redis_client.setex.call_args.args[1] == 86400
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
redis_client.get.assert_has_calls([call(generation_key), call(generation_key)])
|
||||
redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
|
||||
|
||||
def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None:
|
||||
"""Redis read failures do not block provider discovery for the tenant."""
|
||||
@ -169,10 +201,29 @@ class TestPluginModelProviderCache:
|
||||
client.fetch_model_providers.assert_called_once_with("tenant-1")
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_fetch_plugin_model_providers_refetches_when_cached_payload_batch_read_fails(self) -> None:
    """Redis mget failures do not block provider discovery for the tenant."""
    cache_key = _provider_cache_key("tenant-1", 0)
    legacy_cache_key = _provider_cache_key("tenant-1")
    with patch(f"{MODULE}.redis_client") as redis_client:
        # The generation read succeeds (no generation stored); only the batched
        # payload read raises.
        redis_client.get.return_value = None
        redis_client.mget.side_effect = RedisError("redis unavailable")
        client = Mock()
        client.fetch_model_providers.return_value = [_build_plugin_model_provider()]

        from core.plugin.plugin_service import PluginService

        result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)

    client.fetch_model_providers.assert_called_once_with("tenant-1")
    redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
    assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None:
|
||||
"""Redis write failures are non-fatal after fresh provider data has been fetched."""
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.get.return_value = None
|
||||
redis_client.mget.return_value = [None, None]
|
||||
redis_client.setex.side_effect = RedisError("redis unavailable")
|
||||
client = Mock()
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
@ -191,6 +242,7 @@ class TestPluginModelProviderCache:
|
||||
patch(f"{MODULE}.PluginModelClient") as client_cls,
|
||||
):
|
||||
redis_client.get.return_value = None
|
||||
redis_client.mget.return_value = [None, None]
|
||||
client = client_cls.return_value
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
@ -202,23 +254,98 @@ class TestPluginModelProviderCache:
|
||||
client.fetch_model_providers.assert_called_once_with("tenant-1")
|
||||
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None:
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
def test_fetch_plugin_model_providers_reuses_process_local_cache(self) -> None:
|
||||
generation_key = _provider_generation_key("tenant-1")
|
||||
with (
|
||||
patch(f"{MODULE}.redis_client") as redis_client,
|
||||
patch(f"{MODULE}.PluginModelClient") as client_cls,
|
||||
):
|
||||
redis_client.get.side_effect = [None, None, None]
|
||||
redis_client.mget.return_value = [None, None]
|
||||
client = client_cls.return_value
|
||||
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
|
||||
first_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1")
|
||||
redis_client.get.reset_mock()
|
||||
redis_client.mget.reset_mock()
|
||||
redis_client.setex.reset_mock()
|
||||
client.fetch_model_providers.reset_mock()
|
||||
|
||||
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
|
||||
second_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1")
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None:
|
||||
redis_client.get.assert_called_once_with(generation_key)
|
||||
redis_client.mget.assert_not_called()
|
||||
redis_client.setex.assert_not_called()
|
||||
client.fetch_model_providers.assert_not_called()
|
||||
assert [provider.provider for provider in second_result] == ["langgenius/openai/openai"]
|
||||
assert second_result[0] == first_result[0]
|
||||
assert second_result[0] is not first_result[0]
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_uses_redis_pipeline(self) -> None:
|
||||
with patch(f"{MODULE}.redis_client") as redis_client:
|
||||
redis_client.delete.side_effect = RedisError("redis unavailable")
|
||||
pipe = redis_client.pipeline.return_value
|
||||
|
||||
from core.plugin.plugin_service import PluginService
|
||||
|
||||
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
|
||||
|
||||
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
|
||||
redis_client.pipeline.assert_called_once_with(transaction=False)
|
||||
pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
|
||||
pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
|
||||
pipe.execute.assert_called_once_with()
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_ignores_redis_pipeline_failure(self) -> None:
    """A ``RedisError`` from ``pipeline.execute()`` must not escape cache invalidation.

    The delete and incr commands are still expected to have been queued on the
    pipeline before the failing ``execute()`` call.
    """
    with patch(f"{MODULE}.redis_client") as redis_client:
        pipe = redis_client.pipeline.return_value
        pipe.execute.side_effect = RedisError("redis unavailable")

        from core.plugin.plugin_service import PluginService

        # Not wrapped in pytest.raises: the call itself completing is the assertion.
        PluginService.invalidate_plugin_model_providers_cache("tenant-1")

    redis_client.pipeline.assert_called_once_with(transaction=False)
    pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
    pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
    pipe.execute.assert_called_once_with()
|
||||
|
||||
def test_invalidate_plugin_model_providers_cache_clears_process_local_cache(self) -> None:
    """Invalidation empties the process-local provider cache and runs the Redis pipeline.

    An entry is seeded for ``tenant-1`` at generation 0; after invalidation the
    class-level memory cache must be empty, and the pipeline must have received
    the legacy-key delete plus the generation incr.
    """
    with patch(f"{MODULE}.redis_client") as redis_client:
        pipe = redis_client.pipeline.return_value

        from core.plugin.plugin_service import PluginService

        PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()])
        PluginService.invalidate_plugin_model_providers_cache("tenant-1")

    assert PluginService._plugin_model_providers_memory_cache == {}
    redis_client.pipeline.assert_called_once_with(transaction=False)
    pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
    pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
    pipe.execute.assert_called_once_with()
|
||||
|
||||
def test_fetch_plugin_model_providers_ignores_stale_process_local_cache_after_generation_bump(self) -> None:
    """A memory-cache entry from an older generation is bypassed and replaced.

    The process-local cache holds a generation-0 entry while Redis reports
    generation 1. The fetch must fall through to the client, write the fresh
    result under the generation-1 Redis key, and record generation 1 in the
    memory cache.
    """
    generation_key = _provider_generation_key("tenant-1")
    new_cache_key = _provider_cache_key("tenant-1", 1)
    with patch(f"{MODULE}.redis_client") as redis_client:
        # Two generation reads are expected (asserted below); both report 1.
        redis_client.get.side_effect = [b"1", b"1"]
        redis_client.mget.return_value = [None]
        client = Mock()
        client.fetch_model_providers.return_value = [_build_plugin_model_provider(provider="anthropic")]

        from core.plugin.plugin_service import PluginService

        PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()])
        result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)

    client.fetch_model_providers.assert_called_once_with("tenant-1")
    redis_client.get.assert_has_calls([call(generation_key), call(generation_key)])
    redis_client.mget.assert_called_once_with([new_cache_key])
    redis_client.setex.assert_called_once()
    assert redis_client.setex.call_args.args[0] == new_cache_key
    assert PluginService._plugin_model_providers_memory_cache["tenant-1"][0] == 1
    assert [provider.provider for provider in result] == ["langgenius/anthropic/anthropic"]
|
||||
|
||||
|
||||
class TestPluginListEndpointCounts:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user