fix: cache credentials & enterprise calls (#35528)

This commit is contained in:
Yunlu Wen 2026-04-23 23:08:04 +08:00 committed by GitHub
parent e7746cb256
commit 573ec3af9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -15,12 +16,17 @@ from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@@ -35,6 +41,7 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@@ -44,7 +51,7 @@ class DifyModelFactory:
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
# Stores the tenant id and the ModelManager later used by init_model_instance.
# A caller-supplied model_manager is used as-is; only the fallback construction differs
# between the two assignment lines below.
self.tenant_id = tenant_id
# NOTE(review): this line appears to be the removed (pre-change) side of the hunk -- the
# diff +/- markers are not visible here. It built the fallback manager with default arguments.
self.model_manager = model_manager or ModelManager()
# Added (post-change) side: the fallback manager is built with enable_credentials_cache=True.
# An explicitly passed model_manager keeps whatever cache setting it was created with.
self.model_manager = model_manager or ModelManager(enable_credentials_cache=True)
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(

View File

@@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
@@ -33,11 +34,13 @@ class ModelInstance:
Model instance class
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None):
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@@ -477,8 +480,10 @@ class ModelInstance:
class ModelManager:
def __init__(self):
# NOTE(review): the signature above appears to be the removed (pre-change) side of the hunk;
# the signature below is its replacement. The diff +/- markers are not visible here.
def __init__(self, enable_credentials_cache: bool = False):
# enable_credentials_cache defaults to False, so existing ModelManager() call sites keep
# working without opting in to caching.
self._provider_manager = ProviderManager()
# Per-instance credentials cache. get_model_instance builds the 4-string key as
# (tenant_id, provider, model_type.value, model) and stores a deepcopy of the
# ModelInstance's credentials under it.
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
# get_model_instance only writes to the cache when this flag is true, so with the
# default (False) the cache stays empty and every call builds credentials afresh.
self._enable_credentials_cache = enable_credentials_cache
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
"""
@@ -496,7 +501,19 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
return ModelInstance(provider_model_bundle, model)
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
@@ -98,7 +99,9 @@ def try_join_default_workspace(account_id: str) -> None:
class EnterpriseService:
@classmethod
# ttl_cache is imported from cachetools.func; ttl=5 keeps a result for 5 seconds, so repeated
# calls inside that window reuse the stored value instead of issuing another request.
# The decorator sits under @classmethod, so it wraps the plain function and cls is part of
# the cache key.
@ttl_cache(ttl=5)
def get_info(cls):
# Sends GET /info through EnterpriseRequest and returns whatever send_request returns.
# NOTE(review): a cached result is handed back by reference -- confirm that no caller
# mutates the returned object, since every caller within the TTL window would share it.
return EnterpriseRequest.send_request("GET", "/info")