fix tests

yunlu.wen 2026-04-30 18:02:09 +08:00
parent 131facbc65
commit 187e12956a
13 changed files with 110 additions and 36 deletions

View File

@@ -24,6 +24,7 @@ from core.entities.provider_entities import (
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from graphon.model_runtime import ModelRuntime
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@@ -33,7 +34,6 @@ from graphon.model_runtime.entities.provider_entities import (
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@@ -1392,10 +1392,12 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = self.get_model_provider_factory()
from core.plugin.impl.model_runtime_factory import create_model_type_instance
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
model_provider_factory = self.get_model_provider_factory()
return create_model_type_instance(
factory=model_provider_factory, provider=self.provider.provider, model_type=model_type
)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None

View File

@@ -4,7 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@@ -44,8 +44,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
model_type_instance = create_model_type_instance(
factory=model_provider_factory, provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke(

View File

@@ -3,6 +3,14 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
if TYPE_CHECKING:
@@ -10,6 +18,15 @@ if TYPE_CHECKING:
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
_MODEL_TYPE_CLASS_MAP: dict[ModelType, type[AIModel]] = {
ModelType.LLM: LargeLanguageModel,
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
ModelType.RERANK: RerankModel,
ModelType.SPEECH2TEXT: Speech2TextModel,
ModelType.MODERATION: ModerationModel,
ModelType.TTS: TTSModel,
}
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
@@ -87,3 +104,30 @@ def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
def create_model_type_instance(
factory: ModelProviderFactory,
provider: str,
model_type: ModelType,
) -> AIModel:
"""Instantiate the AIModel subclass for *model_type* backed by *factory*'s runtime.
This replaces ``ModelProviderFactory.get_model_type_instance`` which was
removed in graphon 0.3.0. The mapping from ModelType to concrete AIModel
subclass is maintained here so that callers do not need to know the
subclass constructors.
:param factory: factory whose ``runtime`` and provider resolution are used.
:param provider: provider identifier (canonical or short name).
:param model_type: the model type to instantiate.
:returns: an AIModel subclass instance wired to the factory's runtime.
:raises ValueError: if *model_type* is not supported.
"""
model_class = _MODEL_TYPE_CLASS_MAP.get(model_type)
if model_class is None:
msg = f"Unsupported model type: {model_type}"
raise ValueError(msg)
provider_entity = factory.get_model_provider(provider)
return model_class(provider_schema=provider_entity, model_runtime=factory.runtime)
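For callers migrating off the removed factory method, here is a minimal usage sketch of the new helper (not part of the diff; the tenant id is a placeholder and "openai" stands in for any installed provider):

from core.plugin.impl.model_runtime_factory import (
    create_model_type_instance,
    create_plugin_model_provider_factory,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel

# Build a tenant-scoped provider factory, then resolve the concrete AIModel subclass.
factory = create_plugin_model_provider_factory(tenant_id="example-tenant")  # placeholder tenant id
llm = create_model_type_instance(factory=factory, provider="openai", model_type=ModelType.LLM)
assert isinstance(llm, LargeLanguageModel)  # ModelType.LLM maps to LargeLanguageModel in _MODEL_TYPE_CLASS_MAP

This is the same pattern the updated call sites in check_moderation, ProviderConfiguration.get_model_type_instance, and the test fixtures now follow.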

View File

@@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
node_id: str,
config: DatasourceNodeData,
data: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory
from graphon.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderType
@@ -16,7 +16,9 @@ def get_mocked_fetch_model_config(
credentials: dict,
):
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
model_type_instance = create_model_type_instance(
factory=model_provider_factory, provider=provider, model_type=ModelType.LLM,
)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",

View File

@@ -429,20 +429,25 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory.get_model_schema.return_value = mock_schema
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder:
with (
patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder,
patch(
"core.plugin.impl.model_runtime_factory.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_create_instance,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_builder.call_count == 2
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM)
mock_factory.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
@@ -459,7 +464,6 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory.get_model_schema.return_value = mock_schema
with (
@@ -467,6 +471,10 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
) as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
patch(
"core.plugin.impl.model_runtime_factory.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_create_instance,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
@@ -476,6 +484,7 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(runtime=bound_runtime)
mock_factory_builder.assert_not_called()
mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM)
def test_get_provider_model_returns_none_when_model_not_found() -> None:

View File

@@ -1,5 +1,6 @@
from types import SimpleNamespace
from typing import cast
from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture
@@ -68,8 +69,9 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
factory = Mock()
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model)
assert (
check_moderation(
@@ -119,8 +121,9 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
factory = Mock()
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model)
assert (
check_moderation(
@@ -147,8 +150,9 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
failing_model = SimpleNamespace(
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
factory = Mock()
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
mocker.patch("core.helper.moderation.create_model_type_instance", return_value=failing_model)
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
check_moderation(

View File

@@ -2,6 +2,7 @@ from unittest.mock import Mock
import pytest
from core.plugin.impl.model_runtime_factory import create_model_type_instance
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -107,7 +108,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
def test_model_provider_factory_requires_runtime() -> None:
with pytest.raises(ValueError, match="model_runtime is required"):
with pytest.raises(ValueError, match="runtime is required"):
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
@@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
def test_model_provider_factory_raises_for_unknown_provider() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime([
runtime=_FakeModelRuntime([
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
@@ -384,7 +385,7 @@ def test_model_provider_factory_builds_model_type_instances(
expected_type: type[object],
) -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime([
runtime=_FakeModelRuntime([
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
@@ -394,14 +395,14 @@
)
)
instance = factory.get_model_type_instance("openai", model_type)
instance = create_model_type_instance(factory=factory, provider="openai", model_type=model_type)
assert isinstance(instance, expected_type)
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime([
runtime=_FakeModelRuntime([
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
@@ -412,4 +413,4 @@ def test_model_provider_factory_rejects_unsupported_model_type() -> None:
)
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
create_model_type_instance(factory=factory, provider="openai", model_type="unsupported") # type: ignore[arg-type]

View File

@@ -32,5 +32,5 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)

View File

@@ -250,7 +250,7 @@ class TestUserAction:
("field_name", "value"),
[
("id", "a" * 21),
("title", "b" * 21),
("title", "b" * 101),
],
)
def test_user_action_length_limits(self, field_name: str, value: str):

View File

@@ -15,7 +15,7 @@ from core.app.llm.model_access import (
)
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_runtime
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
@@ -243,7 +243,9 @@ def model_config(monkeypatch):
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(runtime=create_plugin_model_runtime(tenant_id="test"))
provider_instance = model_provider_factory.get_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
model_type_instance = create_model_type_instance(
factory=model_provider_factory, provider="openai", model_type=ModelType.LLM,
)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(

View File

@@ -188,10 +188,16 @@ def test_run_extract_text(
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
if extension:
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_pdf_extract)
else:
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_pdf_extract)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
if extension:
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_docx_extract)
else:
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_docx_extract)
result = document_extractor_node._run()
@@ -439,13 +445,18 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
file.extension = extension
file.mime_type = mime_type
extract_patch_target = (
"graphon.nodes.document_extractor.node._extract_text_by_file_extension"
if extension
else "graphon.nodes.document_extractor.node._extract_text_by_mime_type"
)
with (
patch(
"graphon.nodes.document_extractor.node._download_file_content",
return_value=b"excel",
),
patch(
"graphon.nodes.document_extractor.node._extract_text_from_excel",
extract_patch_target,
return_value="excel text",
) as mock_extract,
):
@@ -456,7 +467,6 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
)
assert result == "excel text"
mock_extract.assert_called_once_with(b"excel")
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):

View File

@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
mock_node_cls.assert_called_once_with(
node_id="n-1",
config=sentinel.node_data,
data=sentinel.node_data,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,