fix(workflow): cache provider configurations during graph init (#35447)

This commit is contained in:
-LAN- 2026-04-21 12:29:35 +08:00 committed by GitHub
parent c2a5962023
commit 77d6c108e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 0 deletions

View File

@ -70,12 +70,32 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
Configuration assembly is cached per manager instance so call chains that
share one request-scoped manager can reuse the same provider graph instead
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
when a long-lived manager needs to observe writes performed within the same
instance scope.
"""
decoding_rsa_key: Any | None
decoding_cipher_rsa: Any | None
_model_runtime: ModelRuntime
_configurations_cache: dict[str, ProviderConfigurations]
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
self._configurations_cache = {}
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
"""Drop assembled provider configurations cached on this manager instance."""
if tenant_id is None:
self._configurations_cache.clear()
return
self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@ -114,6 +134,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
cached_configurations = self._configurations_cache.get(tenant_id)
if cached_configurations is not None:
return cached_configurations
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@ -273,6 +297,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
self._configurations_cache[tenant_id] = provider_configurations
# Return the encapsulated object
return provider_configurations

View File

@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration(
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
patch(
"core.provider_manager.ProviderConfiguration",
return_value=provider_configuration,
) as mock_provider_configuration,
):
first = manager.get_configurations("tenant-id")
second = manager.get_configurations("tenant-id")
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_provider_configuration.assert_called_once()
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity):
manager = _build_provider_manager(mocker)
provider_factory = Mock()
provider_factory.get_providers.return_value = [mock_provider_entity]
custom_configuration = SimpleNamespace(provider=None, models=[])
system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
provider_configuration_first = Mock()
provider_configuration_second = Mock()
with (
patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
patch.object(manager, "_get_all_provider_model_settings", return_value={}),
patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
patch.object(manager, "_to_system_configuration", return_value=system_configuration),
patch.object(manager, "_to_model_settings", return_value=[]),
patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
patch(
"core.provider_manager.ProviderConfiguration",
side_effect=[provider_configuration_first, provider_configuration_second],
) as mock_provider_configuration,
):
first = manager.get_configurations("tenant-id")
manager.clear_configurations_cache("tenant-id")
second = manager.get_configurations("tenant-id")
assert first is not second
assert mock_get_all_providers.call_count == 2
assert mock_provider_configuration.call_count == 2
provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime)
provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime)
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()