diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..1d6c377d8f 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import deepcopy from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -15,12 +16,17 @@ from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager + credentials_cache: dict[tuple[str, str], dict[str, Any]] def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: self.tenant_id = tenant_id self.provider_manager = provider_manager or ProviderManager() + self.credentials_cache = {} def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + if (provider_name, model_name) in self.credentials_cache: + return deepcopy(self.credentials_cache[(provider_name, model_name)]) + provider_configurations = self.provider_manager.get_configurations(self.tenant_id) provider_configuration = provider_configurations.get(provider_name) if not provider_configuration: @@ -35,6 +41,7 @@ class DifyCredentialsProvider: if credentials is None: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) return credentials @@ -44,7 +51,7 @@ class DifyModelFactory: def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + self.model_manager = model_manager or ModelManager(enable_credentials_cache=True) def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..7eab84b5bb 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,5 +1,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence +from copy import deepcopy from typing import IO, Any, Literal, Optional, Union, cast, overload from configs import dify_config @@ -33,11 +34,13 @@ class ModelInstance: Model instance class """ - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None): self.provider_model_bundle = provider_model_bundle self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + if credentials is None: + credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + self.credentials = credentials # Runtime LLM invocation fields. self.parameters: Mapping[str, Any] = {} self.stop: Sequence[str] = () @@ -477,8 +480,10 @@ class ModelInstance: class ModelManager: - def __init__(self): + def __init__(self, enable_credentials_cache: bool = False): self._provider_manager = ProviderManager() + self._credentials_cache: dict[tuple[str, str, str, str], Any] = {} + self._enable_credentials_cache = enable_credentials_cache def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: """ @@ -496,7 +501,19 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + cred_cache_key = (tenant_id, provider, model_type.value, model) + + if cred_cache_key in self._credentials_cache: + return ModelInstance( + provider_model_bundle, + model, + deepcopy(self._credentials_cache[cred_cache_key]), + ) + + ret = ModelInstance(provider_model_bundle, model) + if self._enable_credentials_cache: + self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials) + return ret def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3..daec78b94b 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime from typing import TYPE_CHECKING +from cachetools.func import ttl_cache from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config @@ -98,7 +99,9 @@ def try_join_default_workspace(account_id: str) -> None: class EnterpriseService: + @classmethod + @ttl_cache(ttl=5) def get_info(cls): return EnterpriseRequest.send_request("GET", "/info")