mirror of
https://github.com/langgenius/dify.git
synced 2026-06-19 08:31:07 +08:00
fix(workflow): cache provider configurations during graph init (#35447)
This commit is contained in:
parent
c2a5962023
commit
77d6c108e7
@ -70,12 +70,32 @@ class ProviderManager:
|
||||
Request-bound managers may carry caller identity in that runtime, and the
|
||||
resulting ``ProviderConfiguration`` objects must reuse it for downstream
|
||||
model-type and schema lookups.
|
||||
|
||||
Configuration assembly is cached per manager instance so call chains that
|
||||
share one request-scoped manager can reuse the same provider graph instead
|
||||
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
|
||||
when a long-lived manager needs to observe writes performed within the same
|
||||
instance scope.
|
||||
"""
|
||||
|
||||
decoding_rsa_key: Any | None
|
||||
decoding_cipher_rsa: Any | None
|
||||
_model_runtime: ModelRuntime
|
||||
_configurations_cache: dict[str, ProviderConfigurations]
|
||||
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
self._model_runtime = model_runtime
|
||||
self._configurations_cache = {}
|
||||
|
||||
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
|
||||
"""Drop assembled provider configurations cached on this manager instance."""
|
||||
if tenant_id is None:
|
||||
self._configurations_cache.clear()
|
||||
return
|
||||
|
||||
self._configurations_cache.pop(tenant_id, None)
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
|
||||
"""
|
||||
@ -114,6 +134,10 @@ class ProviderManager:
|
||||
:param tenant_id:
|
||||
:return:
|
||||
"""
|
||||
cached_configurations = self._configurations_cache.get(tenant_id)
|
||||
if cached_configurations is not None:
|
||||
return cached_configurations
|
||||
|
||||
# Get all provider records of the workspace
|
||||
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
|
||||
|
||||
@ -273,6 +297,8 @@ class ProviderManager:
|
||||
|
||||
provider_configurations[str(provider_id_entity)] = provider_configuration
|
||||
|
||||
self._configurations_cache[tenant_id] = provider_configurations
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
|
||||
@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration(
|
||||
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_configuration = Mock()
|
||||
provider_factory = Mock()
|
||||
provider_factory.get_providers.return_value = [mock_provider_entity]
|
||||
custom_configuration = SimpleNamespace(provider=None, models=[])
|
||||
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
|
||||
|
||||
with (
|
||||
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
|
||||
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
|
||||
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
|
||||
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
|
||||
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
|
||||
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
|
||||
patch.object(manager, "_to_model_settings", return_value=[]),
|
||||
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
|
||||
patch(
|
||||
"core.provider_manager.ProviderConfiguration",
|
||||
return_value=provider_configuration,
|
||||
) as mock_provider_configuration,
|
||||
):
|
||||
first = manager.get_configurations("tenant-id")
|
||||
second = manager.get_configurations("tenant-id")
|
||||
|
||||
assert first is second
|
||||
mock_get_all_providers.assert_called_once_with("tenant-id")
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_provider_configuration.assert_called_once()
|
||||
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_factory = Mock()
|
||||
provider_factory.get_providers.return_value = [mock_provider_entity]
|
||||
custom_configuration = SimpleNamespace(provider=None, models=[])
|
||||
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
|
||||
provider_configuration_first = Mock()
|
||||
provider_configuration_second = Mock()
|
||||
|
||||
with (
|
||||
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
|
||||
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
|
||||
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
|
||||
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
|
||||
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
|
||||
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
|
||||
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
|
||||
patch.object(manager, "_to_model_settings", return_value=[]),
|
||||
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
|
||||
patch(
|
||||
"core.provider_manager.ProviderConfiguration",
|
||||
side_effect=[provider_configuration_first, provider_configuration_second],
|
||||
) as mock_provider_configuration,
|
||||
):
|
||||
first = manager.get_configurations("tenant-id")
|
||||
manager.clear_configurations_cache("tenant-id")
|
||||
second = manager.get_configurations("tenant-id")
|
||||
|
||||
assert first is not second
|
||||
assert mock_get_all_providers.call_count == 2
|
||||
assert mock_provider_configuration.call_count == 2
|
||||
provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
|
||||
manager = _build_provider_manager(mocker)
|
||||
provider_configuration = Mock()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user