update ModelProviderFactory

This commit is contained in:
yunlu.wen 2026-04-30 17:00:07 +08:00
parent 9acd149469
commit 131facbc65
8 changed files with 24 additions and 24 deletions

View File

@ -109,7 +109,7 @@ class ProviderConfiguration(BaseModel):
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return ModelProviderFactory(runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:

View File

@ -38,7 +38,7 @@ class PluginModelAssembly:
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
return self._model_provider_factory
@property

View File

@ -165,7 +165,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@ -362,7 +362,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@ -474,7 +474,7 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
mock_factory_cls.assert_called_with(runtime=bound_runtime)
mock_factory_builder.assert_not_called()

View File

@ -73,7 +73,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
provider_schema = factory.get_model_provider("openai")
@ -98,7 +98,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
provider_schema = factory.get_model_provider("openai")
@ -108,7 +108,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
def test_model_provider_factory_requires_runtime() -> None:
with pytest.raises(ValueError, match="model_runtime is required"):
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
@ -119,7 +119,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
supported_model_types=[ModelType.LLM],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
result = factory.get_providers()
@ -133,7 +133,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
result = factory.get_provider_schema("openai")
@ -171,7 +171,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
models=[_build_model("rerank-v3", ModelType.RERANK)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
@ -195,7 +195,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(model_type=ModelType.TTS)
@ -213,7 +213,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai")
@ -241,7 +241,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.provider_credentials_validate(
provider="openai",
@ -257,7 +257,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime([
runtime=_FakeModelRuntime([
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
@ -292,7 +292,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.model_credentials_validate(
provider="openai",
@ -312,7 +312,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime([
runtime=_FakeModelRuntime([
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
@ -343,7 +343,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
)
runtime.get_model_schema.return_value = "schema"
runtime.get_provider_icon.return_value = (b"icon", "image/png")
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
assert (
factory.get_model_schema(

View File

@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
assert assembly.model_manager is model_manager
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(runtime=runtime)
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)

View File

@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
result = manager.get_default_model("tenant-id", ModelType.LLM)
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result is not None
assert result.model == "gpt-4"
assert result.provider.provider == "openai"
@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
result = manager.get_configurations("tenant-id")
expected_alias = str(ModelProviderID("openai"))
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result.tenant_id == "tenant-id"
assert expected_alias in provider_records
assert expected_alias in provider_model_records
@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
mock_provider_configuration.assert_called_once()
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)

View File

@ -241,7 +241,7 @@ def model_config(monkeypatch):
)
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
model_provider_factory = ModelProviderFactory(runtime=create_plugin_model_runtime(tenant_id="test"))
provider_instance = model_provider_factory.get_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)