perf(api): reduce workflow startup latency for chatflow (#36773)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
呆萌闷油瓶 2026-06-11 15:05:35 +08:00 committed by GitHub
parent 632df88228
commit c4a8d79be9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 620 additions and 99 deletions

View File

@ -4,6 +4,7 @@ from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelWithProviderEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
@ -19,6 +20,10 @@ class DifyCredentialsProvider:
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
The matching validated provider model is cached alongside the credentials so
follow-up workflow startup checks can reuse the earlier provider lookup
instead of resolving the same model metadata a second time.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
@ -30,6 +35,7 @@ class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
provider_model_cache: dict[tuple[str, str], ModelWithProviderEntity]
def __init__(
self,
@ -45,12 +51,21 @@ class DifyCredentialsProvider:
)
self.provider_manager = provider_manager
self.credentials_cache = {}
self.provider_model_cache = {}
def get_cached_provider_model(self, provider_name: str, model_name: str) -> ModelWithProviderEntity | None:
provider_model = self.provider_model_cache.get((provider_name, model_name))
if provider_model is None:
return None
return provider_model.model_copy(deep=True)
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
raise ValueError(f"Provider {provider_name} does not exist.")
@ -65,6 +80,7 @@ class DifyCredentialsProvider:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
self.provider_model_cache[(provider_name, model_name)] = provider_model.model_copy(deep=True)
return credentials
@ -142,13 +158,21 @@ def fetch_model_config(
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
provider_model = None
if isinstance(credentials_provider, DifyCredentialsProvider):
provider_model = credentials_provider.get_cached_provider_model(
provider_name=node_data_model.provider,
model_name=node_data_model.name,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} does not exist.")
provider_model.raise_for_status()
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if model_schema is None:

View File

@ -69,6 +69,11 @@ class ProviderConfiguration(BaseModel):
nested schema and model lookups reuse the caller scope that was already
resolved by the composition layer.
The ``provider`` field already contains the resolved provider schema that
was used to build this configuration. Reuse that schema for nested model
lookups instead of refetching the full provider catalog from the runtime on
every request-scoped lookup.
TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
"""
@ -83,15 +88,19 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
_bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None)
_cached_provider_schema: ProviderEntity | None = PrivateAttr(default=None)
_original_provider_configurate_methods: tuple[ConfigurateMethod, ...] = PrivateAttr(default_factory=tuple)
@model_validator(mode="after")
def _(self):
self._original_provider_configurate_methods = tuple(self.provider.configurate_methods)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if list(self._original_provider_configurate_methods) == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if (
any(
len(quota_configuration.restrict_models) > 0
@ -105,6 +114,29 @@ class ProviderConfiguration(BaseModel):
def bind_model_runtime(self, model_runtime: ModelRuntime) -> None:
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
self._cached_provider_schema = self.provider
def _get_original_provider_configurate_methods(self) -> list[ConfigurateMethod]:
return list(self._original_provider_configurate_methods)
def _get_provider_schema(self, *, model_provider_factory: ModelProviderFactory | None = None) -> ProviderEntity:
"""Resolve the provider schema lazily while preserving bound-runtime reuse."""
if self._cached_provider_schema is None:
if self.provider.models:
self._cached_provider_schema = self.provider
else:
provider_factory = model_provider_factory or self.get_model_provider_factory()
self._cached_provider_schema = provider_factory.get_provider_schema(provider=self.provider.provider)
return self._cached_provider_schema
def _get_model_runtime(self) -> ModelRuntime:
"""Return the runtime aligned with this request-scoped configuration."""
if self._bound_model_runtime is not None:
return self._bound_model_runtime
model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
return model_assembly.model_runtime
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
@ -153,7 +185,6 @@ class ProviderConfiguration(BaseModel):
and restrict_model.base_model_name
):
copy_credentials["base_model_name"] = restrict_model.base_model_name
return copy_credentials
else:
credentials = None
@ -189,7 +220,6 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
@ -1399,8 +1429,13 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
if self._bound_model_runtime is not None:
model_runtime = self._bound_model_runtime
else:
model_runtime, _ = self._get_runtime_and_provider_factory()
provider_schema = self._cached_provider_schema or self.provider
return create_model_type_instance(
runtime=model_runtime,
provider_schema=provider_schema,
@ -1410,12 +1445,13 @@ class ProviderConfiguration(BaseModel):
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
model_provider_factory = self.get_model_provider_factory()
return model_provider_factory.get_model_schema(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
"""Get model schema with the request-bound runtime and canonical provider id."""
model_runtime = self._get_model_runtime()
return model_runtime.get_model_schema(
provider=self.provider.provider,
model_type=model_type,
model=model,
credentials=credentials or {},
)
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
@ -1515,8 +1551,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_factory = self.get_model_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
provider_schema = self._get_provider_schema()
model_types: list[ModelType] = []
if model_type:
@ -1531,7 +1566,10 @@ class ProviderConfiguration(BaseModel):
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
model_types=model_types,
provider_schema=provider_schema,
model_setting_map=model_setting_map,
model=model,
)
else:
provider_models = self._get_custom_provider_models(
@ -1573,6 +1611,7 @@ class ProviderConfiguration(BaseModel):
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
model: str | None = None,
) -> list[ModelWithProviderEntity]:
"""
Get system provider models.
@ -1587,6 +1626,8 @@ class ProviderConfiguration(BaseModel):
for m in provider_schema.models:
if m.model_type != model_type:
continue
if model and m.model != model:
continue
status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
@ -1608,13 +1649,9 @@ class ProviderConfiguration(BaseModel):
)
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_schema.configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)
original_configurate_methods = self._get_original_provider_configurate_methods()
should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True
for quota_configuration in self.system_configuration.quota_configurations:
@ -1626,11 +1663,12 @@ class ProviderConfiguration(BaseModel):
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL
]:
if original_configurate_methods == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
if model and restrict_model.model != model:
continue
copy_credentials = (
self.system_configuration.credentials.copy()
if self.system_configuration.credentials
@ -1680,11 +1718,11 @@ class ProviderConfiguration(BaseModel):
# if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for model in provider_models:
if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
model.status = ModelStatus.NO_PERMISSION
for provider_model in provider_models:
if provider_model.model_type == ModelType.LLM and provider_model.model not in restrict_model_names:
provider_model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
model.status = ModelStatus.QUOTA_EXCEEDED
provider_model.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
@ -1709,6 +1747,13 @@ class ProviderConfiguration(BaseModel):
if self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
requested_predefined_model = False
if model:
requested_predefined_model = any(
predefined_model.model_type in model_types and predefined_model.model == model
for predefined_model in provider_schema.models
)
for model_type in model_types:
if model_type not in self.provider.supported_model_types:
continue
@ -1716,6 +1761,8 @@ class ProviderConfiguration(BaseModel):
for m in provider_schema.models:
if m.model_type != model_type:
continue
if requested_predefined_model and model and m.model != model:
continue
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False

View File

@ -13,8 +13,10 @@ metadata.
"""
import logging
import time
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import ClassVar
from pydantic import BaseModel, TypeAdapter, ValidationError
from redis import RedisError
@ -64,6 +66,8 @@ _provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list
class PluginService:
_plugin_model_providers_memory_cache: ClassVar[dict[str, tuple[int, float, tuple[ProviderEntity, ...]]]] = {}
class LatestPluginCache(BaseModel):
plugin_id: str
version: str
@ -75,14 +79,22 @@ class PluginService:
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX = "plugin_model_providers_generation:tenant_id:"
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
# Mirror the detail-panel endpoint query size so list reconciliation and
# the visible endpoint drawer exercise the same daemon pagination path.
ENDPOINT_RECONCILIATION_PAGE_SIZE = 100
@classmethod
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
def _get_plugin_model_providers_cache_key(cls, tenant_id: str, generation: int | None = None) -> str:
if generation is None:
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}:generation:{generation}"
@classmethod
def _get_plugin_model_providers_generation_cache_key(cls, tenant_id: str) -> str:
return f"{cls.PLUGIN_MODEL_PROVIDERS_GENERATION_REDIS_KEY_PREFIX}{tenant_id}"
@staticmethod
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
@ -115,29 +127,133 @@ class PluginService:
return declaration
@classmethod
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
def _copy_provider_entities(cls, providers: Sequence[ProviderEntity]) -> tuple[ProviderEntity, ...]:
return tuple(provider.model_copy(deep=True) for provider in providers)
@classmethod
def _load_plugin_model_providers_generation(cls, tenant_id: str) -> int | None:
cache_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id)
try:
cached_providers = redis_client.get(cache_key)
cached_generation = redis_client.get(cache_key)
except (RedisError, RuntimeError):
logger.warning("Failed to read plugin model provider generation for tenant %s.", tenant_id, exc_info=True)
return None
if cached_generation is None:
return 0
try:
return int(cached_generation)
except (TypeError, ValueError):
logger.warning(
"Invalid plugin model provider generation for tenant %s; deleting cache marker.",
tenant_id,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError):
logger.warning(
"Failed to delete invalid plugin model provider generation for tenant %s.",
tenant_id,
exc_info=True,
)
return None
@classmethod
def _load_in_memory_plugin_model_providers(
cls, memory_cache_key: str, generation: int
) -> tuple[ProviderEntity, ...] | None:
cached_entry = cls._plugin_model_providers_memory_cache.get(memory_cache_key)
if cached_entry is None:
return None
cached_generation, expires_at, providers = cached_entry
if cached_generation != generation or time.monotonic() >= expires_at:
cls._plugin_model_providers_memory_cache.pop(memory_cache_key, None)
return None
return cls._copy_provider_entities(providers)
@classmethod
def _store_in_memory_plugin_model_providers(
cls, memory_cache_key: str, generation: int, providers: Sequence[ProviderEntity]
) -> None:
ttl = dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL
if ttl <= 0:
cls._plugin_model_providers_memory_cache.pop(memory_cache_key, None)
return
cls._plugin_model_providers_memory_cache[memory_cache_key] = (
generation,
time.monotonic() + ttl,
cls._copy_provider_entities(providers),
)
@classmethod
def _load_cached_plugin_model_providers(
cls, tenant_id: str, *, client: PluginModelClient | None = None
) -> tuple[ProviderEntity, ...] | None:
generation = cls._load_plugin_model_providers_generation(tenant_id)
if generation is not None:
in_memory_cached_providers = cls._load_in_memory_plugin_model_providers(tenant_id, generation)
if in_memory_cached_providers is not None:
return in_memory_cached_providers
cache_keys = []
if generation is not None:
cache_keys.append(cls._get_plugin_model_providers_cache_key(tenant_id, generation))
if generation == 0:
cache_keys.append(cls._get_plugin_model_providers_cache_key(tenant_id))
if not cache_keys:
return None
try:
cached_provider_entries = redis_client.mget(cache_keys)
except (RedisError, RuntimeError):
logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
return None
if not cached_providers:
if len(cached_provider_entries) != len(cache_keys):
logger.warning(
"Unexpected cached plugin model providers response size for tenant %s.",
tenant_id,
)
return None
try:
return tuple(_provider_entities_adapter.validate_json(cached_providers))
except (TypeError, ValueError, ValidationError):
logger.warning(
"Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True
)
cls.invalidate_plugin_model_providers_cache(tenant_id)
return None
for cache_key, cached_providers in zip(cache_keys, cached_provider_entries):
if not cached_providers:
continue
try:
providers = tuple(_provider_entities_adapter.validate_json(cached_providers))
if generation is not None:
cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers)
return providers
except (TypeError, ValueError, ValidationError):
logger.warning(
"Invalid cached plugin model providers for tenant %s; deleting cache key %s.",
tenant_id,
cache_key,
exc_info=True,
)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError):
logger.warning(
"Failed to delete invalid cached plugin model providers for tenant %s.",
tenant_id,
exc_info=True,
)
return None
@classmethod
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
def _store_cached_plugin_model_providers(
cls, tenant_id: str, generation: int, providers: Sequence[ProviderEntity]
) -> None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id, generation)
try:
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
@ -146,9 +262,15 @@ class PluginService:
@classmethod
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
"""Delete the tenant-scoped plugin model provider list cache."""
"""Invalidate tenant-scoped provider metadata across Redis and worker-local mirrors."""
cls._plugin_model_providers_memory_cache.pop(tenant_id, None)
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
generation_key = cls._get_plugin_model_providers_generation_cache_key(tenant_id)
try:
redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id))
pipe = redis_client.pipeline(transaction=False)
pipe.delete(cache_key)
pipe.incr(generation_key)
pipe.execute()
except (RedisError, RuntimeError):
logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
@ -163,7 +285,7 @@ class PluginService:
are intentionally owned by this service so tenant isolation and cache
expiry are handled in one place.
"""
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
cached_providers = cls._load_cached_plugin_model_providers(tenant_id, client=client)
if cached_providers is not None:
return cached_providers
@ -171,7 +293,12 @@ class PluginService:
providers = tuple(
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
)
cls._store_cached_plugin_model_providers(tenant_id, providers)
if not providers:
return providers
generation = cls._load_plugin_model_providers_generation(tenant_id)
if generation is not None:
cls._store_in_memory_plugin_model_providers(tenant_id, generation, providers)
cls._store_cached_plugin_model_providers(tenant_id, generation, providers)
return providers
@staticmethod

View File

@ -433,10 +433,9 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
mock_assembly = Mock()
mock_assembly.model_runtime = Mock()
mock_assembly.model_runtime.get_model_schema.return_value = mock_schema
mock_assembly.model_provider_factory = mock_factory
with (
@ -455,13 +454,12 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_assembly_builder.call_count == 2
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=mock_assembly.model_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
mock_factory.get_model_schema.assert_called_once_with(
mock_assembly.model_runtime.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
model="gpt-4o",
@ -472,18 +470,13 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None:
configuration = _build_provider_configuration()
bound_runtime = Mock()
bound_runtime.get_model_schema.return_value = _build_ai_model("gpt-4o")
configuration.bind_model_runtime(bound_runtime)
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
with (
patch(
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
) as mock_factory_cls,
patch("core.entities.provider_configuration.ModelProviderFactory") as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
@ -494,16 +487,20 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(runtime=bound_runtime)
assert model_schema == bound_runtime.get_model_schema.return_value
mock_factory_cls.assert_not_called()
mock_assembly_builder.assert_not_called()
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=bound_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
bound_runtime.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
model="gpt-4o",
credentials={"api_key": "x"},
)
def test_get_provider_model_returns_none_when_model_not_found() -> None:
@ -544,6 +541,99 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
assert [model.model for model in active_models] == ["b-model"]
def test_get_provider_models_system_filters_requested_model() -> None:
configuration = _build_provider_configuration()
provider_schema = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[_build_ai_model("a-model"), _build_ai_model("target-model"), _build_ai_model("b-model")],
)
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
models = configuration.get_provider_models(
model_type=ModelType.LLM,
only_active=False,
model="target-model",
)
assert [model.model for model in models] == ["target-model"]
def test_get_provider_models_system_customizable_filters_requested_restricted_model() -> None:
provider = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL],
)
system_configuration = SystemConfiguration(
enabled=True,
credentials={"api_key": "test-key"},
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
QuotaConfiguration(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=1_000,
quota_used=0,
is_valid=True,
restrict_models=[
RestrictModel(model="target-model", base_model_name="base-model", model_type=ModelType.LLM),
RestrictModel(model="other-model", base_model_name="base-model", model_type=ModelType.LLM),
],
)
],
)
provider_schema = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[],
)
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}):
configuration = ProviderConfiguration(
tenant_id="tenant-1",
provider=provider,
preferred_provider_type=ProviderType.SYSTEM,
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
custom_configuration=CustomConfiguration(provider=None, models=[]),
model_settings=[],
)
with (
patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
),
patch.object(
ProviderConfiguration,
"get_model_schema",
side_effect=lambda *args, **kwargs: _build_ai_model(kwargs["model"]),
) as mock_get_model_schema,
):
models = configuration.get_provider_models(
model_type=ModelType.LLM,
only_active=False,
model="target-model",
)
assert [model.model for model in models] == ["target-model"]
mock_get_model_schema.assert_called_once()
assert mock_get_model_schema.call_args.kwargs["model"] == "target-model"
def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None:
configuration = _build_provider_configuration()
configuration.using_provider_type = ProviderType.CUSTOM
@ -611,6 +701,48 @@ def test_get_custom_provider_models_sets_status_for_removed_credentials_and_inva
assert invalid_lb_map["custom-model"] is True
def test_get_custom_provider_models_filters_requested_base_model() -> None:
configuration = _build_provider_configuration()
configuration.using_provider_type = ProviderType.CUSTOM
configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"})
provider_schema = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[_build_ai_model("base-model"), _build_ai_model("target-model")],
)
models = configuration._get_custom_provider_models(
model_types=[ModelType.LLM],
provider_schema=provider_schema,
model_setting_map={},
model="target-model",
)
assert [model.model for model in models] == ["target-model"]
def test_get_provider_models_reuses_cached_provider_schema() -> None:
configuration = _build_provider_configuration()
provider_schema = ProviderEntity(
provider="openai",
label=I18nObject(en_US="OpenAI"),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[_build_ai_model("a-model"), _build_ai_model("b-model")],
)
configuration.provider = provider_schema
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
) as mock_assembly_builder:
configuration.get_provider_models(model_type=ModelType.LLM, model="a-model")
configuration.get_provider_models(model_type=ModelType.LLM, model="b-model")
mock_assembly_builder.assert_not_called()
def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None:
provider = ProviderEntity(
provider="openai",
@ -1402,25 +1534,22 @@ def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() -
return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING)
return _build_ai_model("target")
with patch(
"core.entities.provider_configuration.original_provider_configurate_methods",
{"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]},
):
with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema):
system_models = configuration._get_system_provider_models(
model_types=[ModelType.LLM],
provider_schema=provider_schema,
model_setting_map={
ModelType.LLM: {
"target": ModelSettings(
model="target",
model_type=ModelType.LLM,
enabled=False,
load_balancing_configs=[],
)
}
},
)
configuration._original_provider_configurate_methods = (ConfigurateMethod.CUSTOMIZABLE_MODEL,)
with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema):
system_models = configuration._get_system_provider_models(
model_types=[ModelType.LLM],
provider_schema=provider_schema,
model_setting_map={
ModelType.LLM: {
"target": ModelSettings(
model="target",
model_type=ModelType.LLM,
enabled=False,
load_balancing_configs=[],
)
}
},
)
assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models)
configuration.using_provider_type = ProviderType.CUSTOM

View File

@ -28,6 +28,9 @@ class _FakeRedis:
def get(self, key: str) -> str | None:
return self._values.get(key)
def mget(self, keys: list[str]) -> list[str | None]:
return [self.get(key) for key in keys]
def setex(self, key: str, ttl: int, value: str) -> None:
self._values[key] = value
self.setex_calls.append((key, ttl, value))
@ -36,6 +39,13 @@ class _FakeRedis:
self._values.pop(key, None)
@pytest.fixture(autouse=True)
def clear_plugin_model_provider_memory_cache() -> None:
PluginService._plugin_model_providers_memory_cache.clear()
yield
PluginService._plugin_model_providers_memory_cache.clear()
def _build_model_schema() -> AIModelEntity:
return AIModelEntity(
model="gpt-4o-mini",
@ -329,6 +339,7 @@ class TestPluginModelRuntime:
"redis_client",
SimpleNamespace(
get=Mock(return_value=None),
mget=Mock(return_value=[None, None]),
delete=Mock(),
setex=Mock(),
),

View File

@ -345,6 +345,62 @@ def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_confi
provider_model.raise_for_status.assert_called_once()
def test_fetch_model_config_reuses_validated_provider_model_from_dify_credentials_provider(
model_config: ModelConfigWithCredentialsEntity,
):
mock_provider_manager = mock.MagicMock()
mock_configurations = mock.MagicMock()
mock_provider_configuration = mock.MagicMock()
mock_provider_model = mock.MagicMock()
mock_model_factory = mock.MagicMock(spec=DifyModelFactory)
mock_configurations.get.return_value = mock_provider_configuration
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
mock_provider_manager.get_configurations.return_value = mock_configurations
run_context = DifyRunContext(
tenant_id="tenant",
app_id="app",
user_id="user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
credentials_provider = DifyCredentialsProvider(
run_context=run_context,
provider_manager=mock_provider_manager,
)
model_instance = mock.MagicMock(
model_type_instance=model_config.provider_model_bundle.model_type_instance,
provider_model_bundle=model_config.provider_model_bundle,
)
mock_model_factory.init_model_instance.return_value = model_instance
with mock.patch.object(
model_instance.model_type_instance.__class__,
"get_model_schema",
return_value=model_config.model_schema,
autospec=True,
):
fetch_model_config(
node_data_model=ModelConfig(
provider="openai",
name="gpt-3.5-turbo",
mode="chat",
completion_params={},
),
credentials_provider=credentials_provider,
model_factory=mock_model_factory,
)
mock_provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
mock_provider_model.raise_for_status.assert_called_once()
def test_dify_model_access_adapters_call_managers():
mock_provider_manager = mock.MagicMock()
mock_model_manager = mock.MagicMock()

View File

@ -1,8 +1,9 @@
import datetime
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, call, patch
import pytest
from pydantic import TypeAdapter
from redis import RedisError
@ -13,6 +14,15 @@ from graphon.model_runtime.entities.provider_entities import ConfigurateMethod,
MODULE = "core.plugin.plugin_service"
@pytest.fixture(autouse=True)
def clear_plugin_model_provider_memory_cache() -> None:
from core.plugin.plugin_service import PluginService
PluginService._plugin_model_providers_memory_cache.clear()
yield
PluginService._plugin_model_providers_memory_cache.clear()
class _FakeSession:
def __init__(self) -> None:
self.execute = Mock()
@ -68,6 +78,17 @@ def _build_install_task(*, task_id: str = "task-1", status: PluginInstallTaskSta
)
def _provider_cache_key(tenant_id: str, generation: int | None = None) -> str:
if generation is None:
return f"plugin_model_providers:tenant_id:{tenant_id}"
return f"plugin_model_providers:tenant_id:{tenant_id}:generation:{generation}"
def _provider_generation_key(tenant_id: str) -> str:
return f"plugin_model_providers_generation:tenant_id:{tenant_id}"
class TestFetchLatestPluginVersion:
def test_skips_marketplace_fetch_when_disabled(self) -> None:
"""Cache misses stay None; marketplace is never called when disabled."""
@ -120,9 +141,13 @@ class TestPluginModelProviderCache:
"""A valid tenant cache entry is reused across runtime calls without plugin daemon access."""
cached_provider = _build_provider_entity()
cached_payload = TypeAdapter(list[ProviderEntity]).dump_json([cached_provider]).decode("utf-8")
generation_key = _provider_generation_key("tenant-1")
cache_key = _provider_cache_key("tenant-1", 0)
legacy_cache_key = _provider_cache_key("tenant-1")
with patch(f"{MODULE}.redis_client") as redis_client:
redis_client.get.return_value = cached_payload
redis_client.get.return_value = None
redis_client.mget.return_value = [cached_payload, None]
from core.plugin.plugin_service import PluginService
@ -132,14 +157,20 @@ class TestPluginModelProviderCache:
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
client.fetch_model_providers.assert_not_called()
redis_client.setex.assert_not_called()
redis_client.get.assert_called_once_with(generation_key)
redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
def test_fetch_plugin_model_providers_deletes_invalid_cache_and_refetches(self) -> None:
"""Invalid cache payloads are tenant-scoped invalidated before falling back to the daemon."""
"""Invalid generation-scoped cache payloads are removed before falling back to the daemon."""
generation_key = _provider_generation_key("tenant-1")
cache_key = _provider_cache_key("tenant-1", 0)
legacy_cache_key = _provider_cache_key("tenant-1")
with (
patch(f"{MODULE}.redis_client") as redis_client,
patch(f"{MODULE}.dify_config") as mock_config,
):
redis_client.get.return_value = "not-json"
redis_client.get.side_effect = [None, None]
redis_client.mget.return_value = ["not-json", None]
mock_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL = 86400
client = Mock()
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
@ -148,12 +179,13 @@ class TestPluginModelProviderCache:
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
cache_key = "plugin_model_providers:tenant_id:tenant-1"
redis_client.delete.assert_called_once_with(cache_key)
redis_client.setex.assert_called_once()
assert redis_client.setex.call_args.args[0] == cache_key
assert redis_client.setex.call_args.args[1] == 86400
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
redis_client.get.assert_has_calls([call(generation_key), call(generation_key)])
redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
def test_fetch_plugin_model_providers_refetches_when_cache_read_fails(self) -> None:
"""Redis read failures do not block provider discovery for the tenant."""
@ -169,10 +201,29 @@ class TestPluginModelProviderCache:
client.fetch_model_providers.assert_called_once_with("tenant-1")
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
def test_fetch_plugin_model_providers_refetches_when_cached_payload_batch_read_fails(self) -> None:
"""Redis mget failures do not block provider discovery for the tenant."""
cache_key = _provider_cache_key("tenant-1", 0)
legacy_cache_key = _provider_cache_key("tenant-1")
with patch(f"{MODULE}.redis_client") as redis_client:
redis_client.get.return_value = None
redis_client.mget.side_effect = RedisError("redis unavailable")
client = Mock()
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
from core.plugin.plugin_service import PluginService
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
client.fetch_model_providers.assert_called_once_with("tenant-1")
redis_client.mget.assert_called_once_with([cache_key, legacy_cache_key])
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
def test_fetch_plugin_model_providers_returns_fresh_result_when_cache_write_fails(self) -> None:
"""Redis write failures are non-fatal after fresh provider data has been fetched."""
with patch(f"{MODULE}.redis_client") as redis_client:
redis_client.get.return_value = None
redis_client.mget.return_value = [None, None]
redis_client.setex.side_effect = RedisError("redis unavailable")
client = Mock()
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
@ -191,6 +242,7 @@ class TestPluginModelProviderCache:
patch(f"{MODULE}.PluginModelClient") as client_cls,
):
redis_client.get.return_value = None
redis_client.mget.return_value = [None, None]
client = client_cls.return_value
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
@ -202,23 +254,98 @@ class TestPluginModelProviderCache:
client.fetch_model_providers.assert_called_once_with("tenant-1")
assert [provider.provider for provider in result] == ["langgenius/openai/openai"]
def test_invalidate_plugin_model_providers_cache_uses_tenant_cache_key(self) -> None:
with patch(f"{MODULE}.redis_client") as redis_client:
def test_fetch_plugin_model_providers_reuses_process_local_cache(self) -> None:
generation_key = _provider_generation_key("tenant-1")
with (
patch(f"{MODULE}.redis_client") as redis_client,
patch(f"{MODULE}.PluginModelClient") as client_cls,
):
redis_client.get.side_effect = [None, None, None]
redis_client.mget.return_value = [None, None]
client = client_cls.return_value
client.fetch_model_providers.return_value = [_build_plugin_model_provider()]
from core.plugin.plugin_service import PluginService
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
first_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1")
redis_client.get.reset_mock()
redis_client.mget.reset_mock()
redis_client.setex.reset_mock()
client.fetch_model_providers.reset_mock()
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
second_result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1")
def test_invalidate_plugin_model_providers_cache_ignores_redis_delete_failure(self) -> None:
redis_client.get.assert_called_once_with(generation_key)
redis_client.mget.assert_not_called()
redis_client.setex.assert_not_called()
client.fetch_model_providers.assert_not_called()
assert [provider.provider for provider in second_result] == ["langgenius/openai/openai"]
assert second_result[0] == first_result[0]
assert second_result[0] is not first_result[0]
def test_invalidate_plugin_model_providers_cache_uses_redis_pipeline(self) -> None:
with patch(f"{MODULE}.redis_client") as redis_client:
redis_client.delete.side_effect = RedisError("redis unavailable")
pipe = redis_client.pipeline.return_value
from core.plugin.plugin_service import PluginService
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
redis_client.delete.assert_called_once_with("plugin_model_providers:tenant_id:tenant-1")
redis_client.pipeline.assert_called_once_with(transaction=False)
pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
pipe.execute.assert_called_once_with()
def test_invalidate_plugin_model_providers_cache_ignores_redis_pipeline_failure(self) -> None:
with patch(f"{MODULE}.redis_client") as redis_client:
pipe = redis_client.pipeline.return_value
pipe.execute.side_effect = RedisError("redis unavailable")
from core.plugin.plugin_service import PluginService
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
redis_client.pipeline.assert_called_once_with(transaction=False)
pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
pipe.execute.assert_called_once_with()
def test_invalidate_plugin_model_providers_cache_clears_process_local_cache(self) -> None:
with patch(f"{MODULE}.redis_client") as redis_client:
pipe = redis_client.pipeline.return_value
from core.plugin.plugin_service import PluginService
PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()])
PluginService.invalidate_plugin_model_providers_cache("tenant-1")
assert PluginService._plugin_model_providers_memory_cache == {}
redis_client.pipeline.assert_called_once_with(transaction=False)
pipe.delete.assert_called_once_with(_provider_cache_key("tenant-1"))
pipe.incr.assert_called_once_with(_provider_generation_key("tenant-1"))
pipe.execute.assert_called_once_with()
def test_fetch_plugin_model_providers_ignores_stale_process_local_cache_after_generation_bump(self) -> None:
generation_key = _provider_generation_key("tenant-1")
new_cache_key = _provider_cache_key("tenant-1", 1)
with patch(f"{MODULE}.redis_client") as redis_client:
redis_client.get.side_effect = [b"1", b"1"]
redis_client.mget.return_value = [None]
client = Mock()
client.fetch_model_providers.return_value = [_build_plugin_model_provider(provider="anthropic")]
from core.plugin.plugin_service import PluginService
PluginService._store_in_memory_plugin_model_providers("tenant-1", 0, [_build_provider_entity()])
result = PluginService.fetch_plugin_model_providers(tenant_id="tenant-1", client=client)
client.fetch_model_providers.assert_called_once_with("tenant-1")
redis_client.get.assert_has_calls([call(generation_key), call(generation_key)])
redis_client.mget.assert_called_once_with([new_cache_key])
redis_client.setex.assert_called_once()
assert redis_client.setex.call_args.args[0] == new_cache_key
assert PluginService._plugin_model_providers_memory_cache["tenant-1"][0] == 1
assert [provider.provider for provider in result] == ["langgenius/anthropic/anthropic"]
class TestPluginListEndpointCounts: