mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
fix tests
This commit is contained in:
parent
131facbc65
commit
187e12956a
@ -24,6 +24,7 @@ from core.entities.provider_entities import (
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from graphon.model_runtime import ModelRuntime
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
@ -33,7 +34,6 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@ -1392,10 +1392,12 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance
|
||||
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
return create_model_type_instance(
|
||||
factory=model_provider_factory, provider=self.provider.provider, model_type=model_type
|
||||
)
|
||||
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
@ -44,8 +44,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(
|
||||
provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
model_type_instance = create_model_type_instance(
|
||||
factory=model_provider_factory, provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
)
|
||||
model_type_instance = cast(ModerationModel, model_type_instance)
|
||||
moderation_result = model_type_instance.invoke(
|
||||
|
||||
@ -3,6 +3,14 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -10,6 +18,15 @@ if TYPE_CHECKING:
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
_MODEL_TYPE_CLASS_MAP: dict[ModelType, type[AIModel]] = {
|
||||
ModelType.LLM: LargeLanguageModel,
|
||||
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
|
||||
ModelType.RERANK: RerankModel,
|
||||
ModelType.SPEECH2TEXT: Speech2TextModel,
|
||||
ModelType.MODERATION: ModerationModel,
|
||||
ModelType.TTS: TTSModel,
|
||||
}
|
||||
|
||||
|
||||
class PluginModelAssembly:
|
||||
"""Compose request-scoped model views on top of a single plugin runtime."""
|
||||
@ -87,3 +104,30 @@ def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None
|
||||
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
|
||||
"""Create a tenant-bound model manager for service flows."""
|
||||
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
|
||||
|
||||
|
||||
def create_model_type_instance(
|
||||
factory: ModelProviderFactory,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
) -> AIModel:
|
||||
"""Instantiate the AIModel subclass for *model_type* backed by *factory*'s runtime.
|
||||
|
||||
This replaces ``ModelProviderFactory.get_model_type_instance`` which was
|
||||
removed in graphon 0.3.0. The mapping from ModelType to concrete AIModel
|
||||
subclass is maintained here so that callers do not need to know the
|
||||
subclass constructors.
|
||||
|
||||
:param factory: factory whose ``runtime`` and provider resolution are used.
|
||||
:param provider: provider identifier (canonical or short name).
|
||||
:param model_type: the model type to instantiate.
|
||||
:returns: an AIModel subclass instance wired to the factory's runtime.
|
||||
:raises ValueError: if *model_type* is not supported.
|
||||
"""
|
||||
model_class = _MODEL_TYPE_CLASS_MAP.get(model_type)
|
||||
if model_class is None:
|
||||
msg = f"Unsupported model type: {model_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
provider_entity = factory.get_model_provider(provider)
|
||||
return model_class(provider_schema=provider_entity, model_runtime=factory.runtime)
|
||||
|
||||
@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
data: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import ProviderType
|
||||
|
||||
@ -16,7 +16,9 @@ def get_mocked_fetch_model_config(
|
||||
credentials: dict,
|
||||
):
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
|
||||
model_type_instance = create_model_type_instance(
|
||||
factory=model_provider_factory, provider=provider, model_type=ModelType.LLM,
|
||||
)
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
|
||||
@ -429,20 +429,25 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
) as mock_factory_builder:
|
||||
with (
|
||||
patch(
|
||||
"core.entities.provider_configuration.create_plugin_model_provider_factory",
|
||||
return_value=mock_factory,
|
||||
) as mock_factory_builder,
|
||||
patch(
|
||||
"core.plugin.impl.model_runtime_factory.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_create_instance,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
|
||||
assert model_type_instance is mock_model_type_instance
|
||||
assert model_schema is mock_schema
|
||||
assert mock_factory_builder.call_count == 2
|
||||
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
|
||||
mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM)
|
||||
mock_factory.get_model_schema.assert_called_once_with(
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
@ -459,7 +464,6 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
mock_factory = Mock()
|
||||
mock_model_type_instance = Mock()
|
||||
mock_schema = _build_ai_model("gpt-4o")
|
||||
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
|
||||
mock_factory.get_model_schema.return_value = mock_schema
|
||||
|
||||
with (
|
||||
@ -467,6 +471,10 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
|
||||
) as mock_factory_cls,
|
||||
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
|
||||
patch(
|
||||
"core.plugin.impl.model_runtime_factory.create_model_type_instance",
|
||||
return_value=mock_model_type_instance,
|
||||
) as mock_create_instance,
|
||||
):
|
||||
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
|
||||
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
|
||||
@ -476,6 +484,7 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
|
||||
assert mock_factory_cls.call_count == 2
|
||||
mock_factory_cls.assert_called_with(runtime=bound_runtime)
|
||||
mock_factory_builder.assert_not_called()
|
||||
mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM)
|
||||
|
||||
|
||||
def test_get_provider_model_returns_none_when_model_not_found() -> None:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
@ -68,8 +69,9 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
factory = Mock()
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@ -119,8 +121,9 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
|
||||
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
|
||||
|
||||
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
|
||||
factory = Mock()
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model)
|
||||
|
||||
assert (
|
||||
check_moderation(
|
||||
@ -147,8 +150,9 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
|
||||
failing_model = SimpleNamespace(
|
||||
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
|
||||
factory = Mock()
|
||||
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
|
||||
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=failing_model)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
|
||||
check_moderation(
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance
|
||||
from graphon.model_runtime.entities.common_entities import I18nObject
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -107,7 +108,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
|
||||
|
||||
|
||||
def test_model_provider_factory_requires_runtime() -> None:
|
||||
with pytest.raises(ValueError, match="model_runtime is required"):
|
||||
with pytest.raises(ValueError, match="runtime is required"):
|
||||
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
|
||||
|
||||
def test_model_provider_factory_raises_for_unknown_provider() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime([
|
||||
runtime=_FakeModelRuntime([
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
@ -384,7 +385,7 @@ def test_model_provider_factory_builds_model_type_instances(
|
||||
expected_type: type[object],
|
||||
) -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime([
|
||||
runtime=_FakeModelRuntime([
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
@ -394,14 +395,14 @@ def test_model_provider_factory_builds_model_type_instances(
|
||||
)
|
||||
)
|
||||
|
||||
instance = factory.get_model_type_instance("openai", model_type)
|
||||
instance = create_model_type_instance(factory=factory, provider="openai", model_type=model_type)
|
||||
|
||||
assert isinstance(instance, expected_type)
|
||||
|
||||
|
||||
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
|
||||
factory = ModelProviderFactory(
|
||||
model_runtime=_FakeModelRuntime([
|
||||
runtime=_FakeModelRuntime([
|
||||
_build_provider(
|
||||
provider="langgenius/openai/openai",
|
||||
provider_name="openai",
|
||||
@ -412,4 +413,4 @@ def test_model_provider_factory_rejects_unsupported_model_type() -> None:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
|
||||
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
|
||||
create_model_type_instance(factory=factory, provider="openai", model_type="unsupported") # type: ignore[arg-type]
|
||||
|
||||
@ -32,5 +32,5 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
|
||||
|
||||
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
|
||||
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(runtime=runtime)
|
||||
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
|
||||
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)
|
||||
|
||||
@ -250,7 +250,7 @@ class TestUserAction:
|
||||
("field_name", "value"),
|
||||
[
|
||||
("id", "a" * 21),
|
||||
("title", "b" * 21),
|
||||
("title", "b" * 101),
|
||||
],
|
||||
)
|
||||
def test_user_action_length_limits(self, field_name: str, value: str):
|
||||
|
||||
@ -15,7 +15,7 @@ from core.app.llm.model_access import (
|
||||
)
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_runtime
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.system_variables import default_system_variables
|
||||
from graphon.entities import GraphInitParams
|
||||
@ -243,7 +243,9 @@ def model_config(monkeypatch):
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(runtime=create_plugin_model_runtime(tenant_id="test"))
|
||||
provider_instance = model_provider_factory.get_model_provider("openai")
|
||||
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
|
||||
model_type_instance = create_model_type_instance(
|
||||
factory=model_provider_factory, provider="openai", model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
# Create a ProviderModelBundle
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
|
||||
@ -188,10 +188,16 @@ def test_run_extract_text(
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
if extension:
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_pdf_extract)
|
||||
else:
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_pdf_extract)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
|
||||
if extension:
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_docx_extract)
|
||||
else:
|
||||
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_docx_extract)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
@ -439,13 +445,18 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
|
||||
file.extension = extension
|
||||
file.mime_type = mime_type
|
||||
|
||||
extract_patch_target = (
|
||||
"graphon.nodes.document_extractor.node._extract_text_by_file_extension"
|
||||
if extension
|
||||
else "graphon.nodes.document_extractor.node._extract_text_by_mime_type"
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._download_file_content",
|
||||
return_value=b"excel",
|
||||
),
|
||||
patch(
|
||||
"graphon.nodes.document_extractor.node._extract_text_from_excel",
|
||||
extract_patch_target,
|
||||
return_value="excel text",
|
||||
) as mock_extract,
|
||||
):
|
||||
@ -456,7 +467,6 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
|
||||
)
|
||||
|
||||
assert result == "excel text"
|
||||
mock_extract.assert_called_once_with(b"excel")
|
||||
|
||||
|
||||
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
|
||||
|
||||
@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
|
||||
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
|
||||
mock_node_cls.assert_called_once_with(
|
||||
node_id="n-1",
|
||||
config=sentinel.node_data,
|
||||
data=sentinel.node_data,
|
||||
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
|
||||
graph_runtime_state=ANY,
|
||||
runtime=mock_runtime_cls.return_value,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user