mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
update ModelProviderFactory
This commit is contained in:
parent
9acd149469
commit
131facbc65
@ -109,7 +109,7 @@ class ProviderConfiguration(BaseModel):
|
||||
def get_model_provider_factory(self) -> ModelProviderFactory:
|
||||
"""Return a provider factory that preserves any request-bound runtime."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return ModelProviderFactory(runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
|
||||
@ -38,7 +38,7 @@ class PluginModelAssembly:
|
||||
@property
|
||||
def model_provider_factory(self) -> ModelProviderFactory:
|
||||
if self._model_provider_factory is None:
|
||||
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
|
||||
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
|
||||
return self._model_provider_factory
|
||||
|
||||
@property
|
||||
|
||||
@ -165,7 +165,7 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
# Get all provider entities
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
|
||||
# Get All preferred provider types of the workspace
|
||||
@ -362,7 +362,7 @@ class ProviderManager:
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
|
||||
|
||||
return DefaultModelEntity(
|
||||
|
||||
@ -474,7 +474,7 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_cls.call_count == 2
|
||||
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
|
||||
mock_factory_cls.assert_called_with(runtime=bound_runtime)
|
||||
mock_factory_builder.assert_not_called()
|
||||
|
||||
|
||||
|
||||
@ -73,7 +73,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
|
||||
supported_model_types=[ModelType.LLM],
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@ -98,7 +98,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
provider_schema = factory.get_model_provider("openai")
|
||||
|
||||
@ -108,7 +108,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
|
||||
def test_model_provider_factory_requires_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="model_runtime is required"):
|
||||
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
|
||||
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
|
||||
@ -119,7 +119,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
result = factory.get_providers()
|
||||
|
||||
@ -133,7 +133,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
provider_name="openai",
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
|
||||
|
||||
result = factory.get_provider_schema("openai")
|
||||
|
||||
@ -171,7 +171,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
|
||||
models=[_build_model("rerank-v3", ModelType.RERANK)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
@ -195,7 +195,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
|
||||
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
|
||||
),
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(model_type=ModelType.TTS)
|
||||
|
||||
@ -213,7 +213,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
|
||||
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
|
||||
)
|
||||
]
|
||||
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
|
||||
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
|
||||
|
||||
results = factory.get_models(provider="openai")
|
||||
|
||||
@ -241,7 +241,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.provider_credentials_validate(
|
||||
provider="openai",
|
||||
@ -257,7 +257,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime([
|
||||
runtime=_FakeModelRuntime([
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
@ -292,7 +292,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
)
|
||||
]
|
||||
)
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
filtered = factory.model_credentials_validate(
|
||||
provider="openai",
|
||||
@ -312,7 +312,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
|
||||
|
||||
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime([
|
||||
runtime=_FakeModelRuntime([
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
@ -343,7 +343,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
|
||||
)
|
||||
runtime.get_model_schema.return_value = "schema"
|
||||
runtime.get_provider_icon.return_value = (b"icon", "image/png")
|
||||
factory = ModelProviderFactory(model_runtime=runtime)
|
||||
factory = ModelProviderFactory(runtime=runtime)
|
||||
|
||||
assert (
|
||||
factory.get_model_schema(
|
||||
|
||||
@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
assert assembly.model_manager is model_manager
|
||||
|
||||
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
||||
@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result is not None
|
||||
assert result.model == "gpt-4"
|
||||
assert result.provider.provider == "openai"
|
||||
@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
|
||||
result = manager.get_configurations("tenant-id")
|
||||
|
||||
expected_alias = str(ModelProviderID("openai"))
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
assert result.tenant_id == "tenant-id"
|
||||
assert expected_alias in provider_records
|
||||
assert expected_alias in provider_model_records
|
||||
@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
|
||||
|
||||
assert first is second
|
||||
mock_get_all_providers.assert_called_once_with("tenant-id")
|
||||
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
|
||||
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
|
||||
mock_provider_configuration.assert_called_once()
|
||||
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
|
||||
|
||||
|
||||
@ -241,7 +241,7 @@ def model_config(monkeypatch):
|
||||
)
|
||||
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
|
||||
model_provider_factory = ModelProviderFactory(runtime=create_plugin_model_runtime(tenant_id="test"))
|
||||
provider_instance = model_provider_factory.get_model_provider("openai")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user