diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 913fc53d81..4c6ea1fdd7 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -24,6 +24,7 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime import ModelRuntime from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, @@ -33,7 +34,6 @@ from graphon.model_runtime.entities.provider_entities import ( ) from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -1392,10 +1392,12 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = self.get_model_provider_factory() + from core.plugin.impl.model_runtime_factory import create_model_type_instance - # Get model instance of LLM - return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) + model_provider_factory = self.get_model_provider_factory() + return create_model_type_instance( + factory=model_provider_factory, provider=self.provider.provider, model_type=model_type + ) def get_model_schema( self, model_type: ModelType, model: str, credentials: dict[str, Any] | None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index f169f247cf..b17b068c73 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,7 +4,7 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeBadRequestError @@ -44,8 +44,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM - model_type_instance = model_provider_factory.get_model_type_instance( - provider=openai_provider_name, model_type=ModelType.MODERATION + model_type_instance = create_model_type_instance( + factory=model_provider_factory, provider=openai_provider_name, model_type=ModelType.MODERATION ) model_type_instance = cast(ModerationModel, model_type_instance) moderation_result = model_type_instance.invoke( diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 12510f0243..3f9763626c 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,6 +3,14 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory if TYPE_CHECKING: @@ -10,6 +18,15 @@ if TYPE_CHECKING: from core.plugin.impl.model_runtime import PluginModelRuntime from core.provider_manager import ProviderManager +_MODEL_TYPE_CLASS_MAP: dict[ModelType, type[AIModel]] = { + ModelType.LLM: LargeLanguageModel, + ModelType.TEXT_EMBEDDING: TextEmbeddingModel, + ModelType.RERANK: RerankModel, + ModelType.SPEECH2TEXT: Speech2TextModel, + ModelType.MODERATION: ModerationModel, + ModelType.TTS: TTSModel, +} + class PluginModelAssembly: """Compose request-scoped model views on top of a single plugin runtime.""" @@ -87,3 +104,30 @@ def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: """Create a tenant-bound model manager for service flows.""" return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager + + +def create_model_type_instance( + factory: ModelProviderFactory, + provider: str, + model_type: ModelType, +) -> AIModel: + """Instantiate the AIModel subclass for *model_type* backed by *factory*'s runtime. + + This replaces ``ModelProviderFactory.get_model_type_instance`` which was + removed in graphon 0.3.0. The mapping from ModelType to concrete AIModel + subclass is maintained here so that callers do not need to know the + subclass constructors. + + :param factory: factory whose ``runtime`` and provider resolution are used. + :param provider: provider identifier (canonical or short name). + :param model_type: the model type to instantiate. + :returns: an AIModel subclass instance wired to the factory's runtime. + :raises ValueError: if *model_type* is not supported. + """ + model_class = _MODEL_TYPE_CLASS_MAP.get(model_type) + if model_class is None: + msg = f"Unsupported model type: {model_type}" + raise ValueError(msg) + + provider_entity = factory.get_model_provider(provider) + return model_class(provider_schema=provider_entity, model_runtime=factory.runtime) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index f3006c4242..a4ef3d1ea7 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, node_id: str, - config: DatasourceNodeData, + data: DatasourceNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..03846446b9 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType @@ -16,7 +16,9 @@ def get_mocked_fetch_model_config( credentials: dict, ): model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") - model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) + model_type_instance = create_model_type_instance( + factory=model_provider_factory, provider=provider, model_type=ModelType.LLM, + ) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 695c89a731..66838ba2ad 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -429,20 +429,25 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", - return_value=mock_factory, - ) as mock_factory_builder: + with ( + patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder, + patch( + "core.plugin.impl.model_runtime_factory.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_create_instance, + ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema assert mock_factory_builder.call_count == 2 - mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", model_type=ModelType.LLM, @@ -459,7 +464,6 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema with ( @@ -467,6 +471,10 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory ) as mock_factory_cls, patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + patch( + "core.plugin.impl.model_runtime_factory.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_create_instance, ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) @@ -476,6 +484,7 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non assert mock_factory_cls.call_count == 2 mock_factory_cls.assert_called_with(runtime=bound_runtime) mock_factory_builder.assert_not_called() + mock_create_instance.assert_called_once_with(factory=mock_factory, provider="openai", model_type=ModelType.LLM) def test_get_provider_model_returns_none_when_model_not_found() -> None: diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py index a0dfa86d20..a8e79b4da0 100644 --- a/api/tests/unit_tests/core/helper/test_moderation.py +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -1,5 +1,6 @@ from types import SimpleNamespace from typing import cast +from unittest.mock import Mock import pytest from pytest_mock import MockerFixture @@ -68,8 +69,9 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk") - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + factory = Mock() mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model) assert ( check_moderation( @@ -119,8 +121,9 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + factory = Mock() mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + mocker.patch("core.helper.moderation.create_model_type_instance", return_value=moderation_model) assert ( check_moderation( @@ -147,8 +150,9 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo failing_model = SimpleNamespace( invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")), ) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model) + factory = Mock() mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + mocker.patch("core.helper.moderation.create_model_type_instance", return_value=failing_model) with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."): check_moderation( diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index e3c1bb3576..87b2e7be87 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from core.plugin.impl.model_runtime_factory import create_model_type_instance from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -107,7 +108,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro def test_model_provider_factory_requires_runtime() -> None: - with pytest.raises(ValueError, match="model_runtime is required"): + with pytest.raises(ValueError, match="runtime is required"): ModelProviderFactory(runtime=None) # type: ignore[arg-type] @@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup def test_model_provider_factory_raises_for_unknown_provider() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime([ + runtime=_FakeModelRuntime([ _build_provider( provider="langgenius/openai/openai", provider_name="openai", @@ -384,7 +385,7 @@ def test_model_provider_factory_builds_model_type_instances( expected_type: type[object], ) -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime([ + runtime=_FakeModelRuntime([ _build_provider( provider="langgenius/openai/openai", provider_name="openai", @@ -394,14 +395,14 @@ def test_model_provider_factory_builds_model_type_instances( ) ) - instance = factory.get_model_type_instance("openai", model_type) + instance = create_model_type_instance(factory=factory, provider="openai", model_type=model_type) assert isinstance(instance, expected_type) def test_model_provider_factory_rejects_unsupported_model_type() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime([ + runtime=_FakeModelRuntime([ _build_provider( provider="langgenius/openai/openai", provider_name="openai", @@ -412,4 +413,4 @@ def test_model_provider_factory_rejects_unsupported_model_type() -> None: ) with pytest.raises(ValueError, match="Unsupported model type: unsupported"): - factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] + create_model_type_instance(factory=factory, provider="openai", model_type="unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py index 8e7b33e49a..52da674f06 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -32,5 +32,5 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views(): mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") mock_provider_factory_cls.assert_called_once_with(runtime=runtime) - mock_provider_manager_cls.assert_called_once_with(runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 734f41dbe1..02a435fb04 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -250,7 +250,7 @@ class TestUserAction: ("field_name", "value"), [ ("id", "a" * 21), - ("title", "b" * 21), + ("title", "b" * 101), ], ) def test_user_action_length_limits(self, field_name: str, value: str): diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 991c67b08f..6f7eb5e035 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -15,7 +15,7 @@ from core.app.llm.model_access import ( ) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams @@ -243,7 +243,9 @@ def model_config(monkeypatch): # Create actual provider and model type instances model_provider_factory = ModelProviderFactory(runtime=create_plugin_model_runtime(tenant_id="test")) provider_instance = model_provider_factory.get_model_provider("openai") - model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) + model_type_instance = create_model_type_instance( + factory=model_provider_factory, provider="openai", model_type=ModelType.LLM, + ) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 2857d1b5d6..16b01fed05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -188,10 +188,16 @@ def test_run_extract_text( if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + if extension: + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_pdf_extract) + else: + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + if extension: + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", mock_docx_extract) + else: + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", mock_docx_extract) result = document_extractor_node._run() @@ -439,13 +445,18 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext file.extension = extension file.mime_type = mime_type + extract_patch_target = ( + "graphon.nodes.document_extractor.node._extract_text_by_file_extension" + if extension + else "graphon.nodes.document_extractor.node._extract_text_by_mime_type" + ) with ( patch( "graphon.nodes.document_extractor.node._download_file_content", return_value=b"excel", ), patch( - "graphon.nodes.document_extractor.node._extract_text_from_excel", + extract_patch_target, return_value="excel text", ) as mock_extract, ): @@ -456,7 +467,6 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext ) assert result == "excel text" - mock_extract.assert_called_once_with(b"excel") def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 7bd1d04914..b8e568683b 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution: mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data) mock_node_cls.assert_called_once_with( node_id="n-1", - config=sentinel.node_data, + data=sentinel.node_data, graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value, graph_runtime_state=ANY, runtime=mock_runtime_cls.return_value,