diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 786094f295..4343a056dd 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -243,6 +243,11 @@ class PluginConfig(BaseSettings): default=15728640 * 12, ) + PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field( + description="TTL in seconds for caching plugin model schemas in Redis", + default=24 * 60 * 60, + ) + class MarketplaceConfig(BaseSettings): """ diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 7c16bc231f..c52dcf8a57 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -29,12 +28,6 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) -plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) - -plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( - ContextVar("plugin_model_schemas") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 45f0335c2e..c3e50eaddd 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,10 +1,11 @@ import decimal import hashlib -from threading import Lock +import logging -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from redis import RedisError -import contexts +from configs import dify_config from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -24,6 +25,9 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) class AIModel(BaseModel): @@ -144,34 +148,60 @@ class AIModel(BaseModel): plugin_model_manager = PluginModelClient() cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model_type=self.model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 28f162a928..64538a6779 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,7 +5,11 @@ import logging from collections.abc import Sequence from threading import Lock +from pydantic import ValidationError +from redis import RedisError + import contexts +from configs import dify_config from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -18,6 +22,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -175,34 +180,60 @@ class ModelProviderFactory: """ plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + cached_schema_json = None try: - contexts.plugin_model_schemas.get() - except LookupError: - contexts.plugin_model_schemas.set({}) - contexts.plugin_model_schema_lock.set(Lock()) - - with contexts.plugin_model_schema_lock.get(): - if cache_key in contexts.plugin_model_schemas.get(): - return contexts.plugin_model_schemas.get()[cache_key] - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, ) + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning( + "Failed to validate cached plugin model schema for model %s", + model, + exc_info=True, + ) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) - if schema: - contexts.plugin_model_schemas.get()[cache_key] = schema + schema = self.plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials or {}, + ) - return schema + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema def get_models( self,