diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 38b87e2cd1..495fd1d898 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -23,7 +23,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, @@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import ( ) from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel): """Attach the already-composed runtime for request-bound call chains.""" self._bound_model_runtime = model_runtime + def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]: + """Resolve a provider factory that stays aligned with the runtime used by the caller.""" + if self._bound_model_runtime is not None: + return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime) + + model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id) + return model_assembly.model_runtime, model_assembly.model_provider_factory + def get_model_provider_factory(self) -> ModelProviderFactory: """Return a provider factory that preserves any request-bound runtime.""" - if self._bound_model_runtime is not None: - return ModelProviderFactory(model_runtime=self._bound_model_runtime) - return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + _, model_provider_factory = self._get_runtime_and_provider_factory() + return model_provider_factory def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None: """ @@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = self.get_model_provider_factory() - - # Get model instance of LLM - return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) + model_runtime, model_provider_factory = self._get_runtime_and_provider_factory() + provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider) + return create_model_type_instance( + runtime=model_runtime, + provider_schema=provider_schema, + model_type=model_type, + ) def get_model_schema( self, model_type: ModelType, model: str, credentials: dict[str, Any] | None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index f169f247cf..18b9b72e9d 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,7 +4,7 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from extensions.ext_hosting_provider import hosting_configuration from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeBadRequestError @@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) - - # Get model instance of LLM - model_type_instance = model_provider_factory.get_model_type_instance( + model_assembly = create_plugin_model_assembly(tenant_id=tenant_id) + model_type_instance = model_assembly.create_model_type_instance( provider=openai_provider_name, model_type=ModelType.MODERATION ) model_type_instance = cast(ModerationModel, model_type_instance) diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 4e66d58b5e..a68220e682 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -20,7 +20,7 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelTy from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 35abd2ae8c..98a5660fdf 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,13 +3,46 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.protocols.runtime import ModelRuntime if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime from core.provider_manager import ProviderManager +_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = { + ModelType.LLM: LargeLanguageModel, + ModelType.TEXT_EMBEDDING: TextEmbeddingModel, + ModelType.RERANK: RerankModel, + ModelType.SPEECH2TEXT: Speech2TextModel, + ModelType.MODERATION: ModerationModel, + ModelType.TTS: TTSModel, +} + + +def create_model_type_instance( + *, + runtime: ModelRuntime, + provider_schema: ProviderEntity, + model_type: ModelType, +) -> AIModel: + """Build the graphon model wrapper explicitly against the request runtime.""" + model_class = _MODEL_CLASS_BY_TYPE.get(model_type) + if model_class is None: + raise ValueError(f"Unsupported model type: {model_type}") + + return model_class(provider_schema=provider_schema, model_runtime=runtime) + class PluginModelAssembly: """Compose request-scoped model views on top of a single plugin runtime.""" @@ -38,9 +71,22 @@ class PluginModelAssembly: @property def model_provider_factory(self) -> ModelProviderFactory: if self._model_provider_factory is None: - self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime) return self._model_provider_factory + def create_model_type_instance( + self, + *, + provider: str, + model_type: ModelType, + ) -> AIModel: + provider_schema = self.model_provider_factory.get_provider_schema(provider=provider) + return create_model_type_instance( + runtime=self.model_runtime, + provider_schema=provider_schema, + model_type=model_type, + ) + @property def provider_manager(self) -> ProviderManager: if self._provider_manager is None: @@ -79,6 +125,20 @@ def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory +def create_plugin_model_type_instance( + *, + tenant_id: str, + provider: str, + model_type: ModelType, + user_id: str | None = None, +) -> AIModel: + """Create a tenant-bound model wrapper for the requested provider and model type.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).create_model_type_instance( + provider=provider, + model_type=model_type, + ) + + def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: """Create a tenant-bound provider manager for service flows.""" return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index b290ae456e..9faa70a0b8 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID from services.feature_service import FeatureService if TYPE_CHECKING: - from graphon.model_runtime.runtime import ModelRuntime + from graphon.model_runtime.protocols.runtime import ModelRuntime _credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) @@ -165,7 +165,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) + model_provider_factory = ModelProviderFactory(runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -362,7 +362,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) + model_provider_factory = ModelProviderFactory(runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 895953a3c1..bdef87abe7 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory): # Re-validate using the resolved node class so workflow-local node schemas # stay explicit and constructors receive the concrete typed payload. resolved_node_data = self._validate_resolved_node_data(node_class, node_data) - config_for_node_init: BaseNodeData | dict[str, Any] - if isinstance(resolved_node_data, BaseNodeData): - config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True) - else: - config_for_node_init = resolved_node_data node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -448,7 +443,7 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( node_id=node_id, - config=config_for_node_init, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 68a24e86b1..17d71668cb 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, node_id: str, - config: AgentNodeData, + data: AgentNodeData, *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, @@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]): ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index f3006c4242..a4ef3d1ea7 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, node_id: str, - config: DatasourceNodeData, + data: DatasourceNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 9c1b7ab2c4..1d60f530a1 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, node_id: str, - config: KnowledgeIndexNodeData, + data: KnowledgeIndexNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 25f73e446d..1aba2737b0 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, node_id: str, - config: KnowledgeRetrievalNodeData, + data: KnowledgeRetrievalNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/pyproject.toml b/api/pyproject.toml index 69add5c68d..8d377c1fe4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ # Emerging: newer and fast-moving, use compatible pins "fastopenapi[flask]~=0.7.0", - "graphon~=0.2.2", + "graphon~=0.3.0", "httpx-sse~=0.4.0", "json-repair~=0.59.4", ] diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index f97b85dc2b..b8c2ed5e6f 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1251,7 +1251,7 @@ class WorkflowService: node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( node_id=node_config["id"], - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 2392084c36..c52ffd1f70 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -71,7 +71,7 @@ def test_node_integration_minimal_stream(mocker): node = DatasourceNode( node_id="n", - config=DatasourceNodeData( + data=DatasourceNodeData( type="datasource", version="1", title="Datasource", diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..a77fe5970a 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType @@ -15,8 +15,9 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") - model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) + model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = model_assembly.model_provider_factory + model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index aaa6092993..86ebb8800a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -66,7 +66,7 @@ def init_code_node(code_config: dict): node = CodeNode( node_id=str(uuid.uuid4()), - config=CodeNodeData.model_validate(code_config["data"]), + data=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b9f7b9575b..e5f1626072 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -76,7 +76,7 @@ def init_http_node(config: dict): node = HttpRequestNode( node_id=str(uuid.uuid4()), - config=HttpRequestNodeData.model_validate(config["data"]), + data=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock): node = HttpRequestNode( node_id=str(uuid.uuid4()), - config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), + data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 3eead70163..cc37a1ff1a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode: node = LLMNode( node_id=str(uuid.uuid4()), - config=LLMNodeData.model_validate(config["data"]), + data=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index f2eabb86c3..5de0167a6f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None): node = ParameterExtractorNode( node_id=str(uuid.uuid4()), - config=ParameterExtractorNodeData.model_validate(config["data"]), + data=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index e2e0723fb8..29737afa4f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -88,7 +88,7 @@ def test_execute_template_transform(): node = TemplateTransformNode( node_id=str(uuid.uuid4()), - config=TemplateTransformNodeData.model_validate(config["data"]), + data=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a8e9422c1e..1c49bbdb97 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -62,7 +62,7 @@ def init_tool_node(config: dict): node = ToolNode( node_id=str(uuid.uuid4()), - config=ToolNodeData.model_validate(config["data"]), + data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index bd13527e14..66b3392a4b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) + variable_pool = VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id=execution_id) + ) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 5aed230cd4..1c5669ba94 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -102,7 +102,7 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( node_id="start", - config=start_data, + data=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -117,7 +117,7 @@ def _build_graph( ) human_node = HumanInputNode( node_id="human", - config=human_data, + data=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -131,7 +131,7 @@ def _build_graph( ) end_node = EndNode( node_id="end", - config=end_data, + data=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 64bcfa9a18..9847f513e4 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -372,7 +372,9 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -583,7 +585,9 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) @@ -617,7 +621,9 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 6104b8d6ca..4a6b32f090 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -58,7 +58,7 @@ class _StubToolNode(Node[_StubToolNodeData]): def __init__( self, node_id: str, - config: _StubToolNodeData, + data: _StubToolNodeData, *, graph_init_params, graph_runtime_state, @@ -66,7 +66,7 @@ class _StubToolNode(Node[_StubToolNodeData]): ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) @@ -167,7 +167,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 58c7bfa4bc..700ccd8c50 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner: def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -162,7 +162,7 @@ class TestWorkflowBasedAppRunner: app_id="app", ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -241,7 +241,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -284,7 +284,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -423,7 +423,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 0bcc1029b0..52a3cdb159 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -283,7 +283,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -725,7 +727,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) @@ -753,7 +757,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -769,7 +775,9 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index d3bd15b6f3..320a3bc42c 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -21,7 +21,9 @@ class TestTriggerPostLayer: ) runtime_state = SimpleNamespace( outputs={"answer": "ok"}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=12, ) @@ -60,7 +62,9 @@ class TestTriggerPostLayer: def test_on_event_handles_missing_trigger_log(self): runtime_state = SimpleNamespace( outputs={}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=0, ) @@ -91,7 +95,9 @@ class TestTriggerPostLayer: def test_on_event_ignores_non_status_events(self): runtime_state = SimpleNamespace( outputs={}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=0, ) diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index cacb4dd4fa..12c8eb08d8 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -60,7 +60,10 @@ def _make_layer( workflow_execution_id="run-id", conversation_id="conv-id", ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool.from_bootstrap(system_variables=system_variables), + start_at=0.0, + ) read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) application_generate_entity = WorkflowAppGenerateEntity.model_construct( diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index a28143026f..1b714d6830 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( @@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None: def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: configuration = _build_provider_configuration() - mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = configuration.provider mock_factory.get_model_schema.return_value = mock_schema + mock_assembly = Mock() + mock_assembly.model_runtime = Mock() + mock_assembly.model_provider_factory = mock_factory - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", - return_value=mock_factory, - ) as mock_factory_builder: + with ( + patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=mock_assembly, + ) as mock_assembly_builder, + patch( + "core.entities.provider_configuration.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_model_builder, + ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema - assert mock_factory_builder.call_count == 2 - mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + assert mock_assembly_builder.call_count == 2 + mock_factory.get_provider_schema.assert_called_once_with(provider="openai") + mock_model_builder.assert_called_once_with( + runtime=mock_assembly.model_runtime, + provider_schema=configuration.provider, + model_type=ModelType.LLM, + ) mock_factory.get_model_schema.assert_called_once_with( provider="openai", model_type=ModelType.LLM, @@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non bound_runtime = Mock() configuration.bind_model_runtime(bound_runtime) - mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = configuration.provider mock_factory.get_model_schema.return_value = mock_schema with ( patch( "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory ) as mock_factory_cls, - patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder, + patch( + "core.entities.provider_configuration.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_model_builder, ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) @@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema assert mock_factory_cls.call_count == 2 - mock_factory_cls.assert_called_with(model_runtime=bound_runtime) - mock_factory_builder.assert_not_called() + mock_factory_cls.assert_called_with(runtime=bound_runtime) + mock_assembly_builder.assert_not_called() + mock_factory.get_provider_schema.assert_called_once_with(provider="openai") + mock_model_builder.assert_called_once_with( + runtime=bound_runtime, + provider_schema=configuration.provider, + model_type=ModelType.LLM, + ) def test_get_provider_model_returns_none_when_model_not_found() -> None: @@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( @@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): @@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory2 = Mock() mock_factory2.model_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2), + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( @@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py index a0dfa86d20..c33002329b 100644 --- a/api/tests/unit_tests/core/helper/test_moderation.py +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk") - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) assert ( check_moderation( @@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture) provider_map={openai_provider: hosting_openai}, ), ) - factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory") + factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly") choice_mock = mocker.patch("core.helper.moderation.secrets.choice") assert ( @@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) assert ( check_moderation( @@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo failing_model = SimpleNamespace( invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")), ) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."): check_moderation( diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index c4fd970562..2b51dc8182 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from core.plugin.impl.model_runtime_factory import create_model_type_instance from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None: supported_model_types=[ModelType.LLM], configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ) - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider])) provider_schema = factory.get_model_provider("openai") @@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) provider_schema = factory.get_model_provider("openai") @@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro def test_model_provider_factory_requires_runtime() -> None: - with pytest.raises(ValueError, match="model_runtime is required"): - ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + with pytest.raises(ValueError, match="runtime is required"): + ModelProviderFactory(runtime=None) # type: ignore[arg-type] def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: @@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non supported_model_types=[ModelType.LLM], ) ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) result = factory.get_providers() @@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup provider_name="openai", supported_model_types=[ModelType.LLM], ) - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider])) result = factory.get_provider_schema("openai") @@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup def test_model_provider_factory_raises_for_unknown_provider() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() -> models=[_build_model("rerank-v3", ModelType.RERANK)], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(provider="openai", model_type=ModelType.LLM) @@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(model_type=ModelType.TTS) @@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], ) ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(provider="openai") @@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None: ) ] ) - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) filtered = factory.provider_credentials_validate( provider="openai", @@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None: def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None: ) ] ) - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) filtered = factory.model_credentials_validate( provider="openai", @@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None: def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider ) runtime.get_model_schema.return_value = "schema" runtime.get_provider_icon.return_value = (b"icon", "image/png") - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) assert ( factory.get_model_schema( @@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider (ModelType.TTS, TTSModel), ], ) -def test_model_provider_factory_builds_model_type_instances( +def test_create_model_type_instance_builds_model_wrappers( model_type: ModelType, expected_type: type[object], ) -> None: - factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( - [ - _build_provider( - provider="langgenius/openai/openai", - provider_name="openai", - supported_model_types=[model_type], - ) - ] - ) + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] ) - instance = factory.get_model_type_instance("openai", model_type) + instance = create_model_type_instance( + runtime=runtime, + provider_schema=runtime.fetch_model_providers()[0], + model_type=model_type, + ) assert isinstance(instance, expected_type) -def test_model_provider_factory_rejects_unsupported_model_type() -> None: - factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( - [ - _build_provider( - provider="langgenius/openai/openai", - provider_name="openai", - supported_model_types=[ModelType.LLM], - ) - ] - ) +def test_create_model_type_instance_rejects_unsupported_model_type() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] ) with pytest.raises(ValueError, match="Unsupported model type: unsupported"): - factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] + create_model_type_instance( + runtime=runtime, + provider_schema=runtime.fetch_model_providers()[0], + model_type="unsupported", # type: ignore[arg-type] + ) diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py index 7491e79f30..52da674f06 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views(): assert assembly.model_manager is model_manager mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") - mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_factory_cls.assert_called_once_with(runtime=runtime) mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 02f12fb3b4..e84fcba3d9 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc result = manager.get_default_model("tenant-id", ModelType.LLM) - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) assert result is not None assert result.model == "gpt-4" assert result.provider.provider == "openai" @@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock result = manager.get_configurations("tenant-id") expected_alias = str(ModelProviderID("openai")) - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) assert result.tenant_id == "tenant-id" assert expected_alias in provider_records assert expected_alias in provider_model_records @@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF assert first is second mock_get_all_providers.assert_called_once_with("tenant-id") - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) mock_provider_configuration.assert_called_once() provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 9f3e3b00b9..c721c7b0eb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory): if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory): }: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory): else: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index f9819c47ec..e0eb4e7361 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -56,7 +56,7 @@ class MockNodeMixin: def __init__( self, node_id: str, - config: Any, + data: Any, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", @@ -98,7 +98,7 @@ class MockNodeMixin: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, **kwargs, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 75bc6d05f7..6156f7b576 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( node_id=start_config["id"], - config=StartNodeData(title="Start", variables=[]), + data=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( node_id=human_a_config["id"], - config=human_data, + data=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( node_id=human_b_config["id"], - config=human_data, + data=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( node_id=end_config["id"], - config=end_data, + data=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index ae9dae0646..387f508154 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -48,7 +48,7 @@ def test_execute_answer(): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -71,7 +71,7 @@ def test_execute_answer(): node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=AnswerNodeData( + data=AnswerNodeData( title="123", type="answer", answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index d7ef781732..86ec730d7f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -79,7 +79,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker): node = DatasourceNode( node_id="n", - config=DatasourceNodeData( + data=DatasourceNodeData( type="datasource", version="1", title="Datasource", diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index be7cc073db..796fc7719d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): # UUID that triggers the json_repair truncation bug test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): """ test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 2e89a2da3c..afde541beb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -110,12 +110,15 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=time.perf_counter(), ) return HttpRequestNode( node_id="http-node", - config=HttpRequestNodeData.model_validate(node_data), + data=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 0659984c76..715292b85c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -149,7 +149,7 @@ def _build_human_input_node( ) return HumanInputNode( node_id=node_id, - config=typed_node_data, + data=typed_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=runtime, @@ -241,16 +241,16 @@ class TestUserAction: def test_user_action_length_boundaries(self): """Test user action id and title length boundaries.""" - action = UserAction(id="a" * 20, title="b" * 20) + action = UserAction(id="a" * 20, title="b" * 100) assert action.id == "a" * 20 - assert action.title == "b" * 20 + assert action.title == "b" * 100 @pytest.mark.parametrize( ("field_name", "value"), [ ("id", "a" * 21), - ("title", "b" * 21), + ("title", "b" * 101), ], ) def test_user_action_length_limits(self, field_name: str, value: str): @@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution: assert params.resolved_default_values == expected_values def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution: assert not hasattr(pause_event.reason, "form_token") def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution: assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user-123", app_id="app", @@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent: """Tests for rendering submitted content.""" def test_replaces_outputs_placeholders_after_submission(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 4a9438b14f..741b104393 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -40,7 +40,7 @@ def _create_human_input_node( ) return HumanInputNode( node_id=config["id"], - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, @@ -51,7 +51,11 @@ def _create_human_input_node( def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + variable_pool=VariablePool.from_bootstrap( + system_variables=system_variables, + user_inputs={}, + environment_variables=[], + ), start_at=0.0, ) graph_init_params = GraphInitParams( @@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# def _build_timeout_node() -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + variable_pool=VariablePool.from_bootstrap( + system_variables=system_variables, + user_inputs={}, + environment_variables=[], + ), start_at=0.0, ) graph_init_params = GraphInitParams( diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 8ffce39cd6..18ed7a0b1d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -32,7 +32,7 @@ class _MissingGraphBuilder: def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -46,7 +46,7 @@ def _build_iteration_node( init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( node_id="iteration-node", - config=IterationNodeData( + data=IterationNodeData( type="iteration", title="Iteration", iterator_selector=["start", "items"], diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index f254fc3d09..f43ed91e43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -40,7 +40,7 @@ def mock_graph_init_params(): @pytest.fixture def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], @@ -102,7 +102,7 @@ def _build_node( ) -> KnowledgeIndexNode: return KnowledgeIndexNode( node_id=node_id, - config=( + data=( node_data if isinstance(node_data, KnowledgeIndexNodeData) else KnowledgeIndexNodeData.model_validate(node_data) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index e923ee761b..705a25c78b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -46,7 +46,7 @@ def mock_graph_init_params(): @pytest.fixture def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], @@ -117,7 +117,7 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -146,7 +146,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -205,7 +205,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -249,7 +249,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -285,7 +285,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -320,7 +320,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -361,7 +361,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -400,7 +400,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -481,7 +481,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -518,7 +518,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -573,7 +573,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -621,7 +621,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -682,7 +682,7 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 388654f279..20b94d5d50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -16,10 +16,10 @@ class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" @staticmethod - def _build_node(*, config, graph_init_params, graph_runtime_state): + def _build_node(*, data, graph_init_params, graph_runtime_state): return ListOperatorNode( node_id="test", - config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) @@ -65,7 +65,7 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -83,7 +83,7 @@ class TestListOperatorNode: } node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -127,7 +127,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -153,7 +153,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -177,7 +177,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -201,7 +201,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -228,7 +228,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -255,7 +255,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,7 +282,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -312,7 +312,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -335,7 +335,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -359,7 +359,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,7 +384,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -408,7 +408,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -432,7 +432,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -456,7 +456,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -483,7 +483,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index c707cf28cd..50a69d3126 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -15,7 +15,7 @@ from core.app.llm.model_access import ( ) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams @@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -208,7 +208,7 @@ def llm_node( http_client = mock.MagicMock() node = LLMNode( node_id="1", - config=llm_node_data, + data=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -241,9 +241,10 @@ def model_config(monkeypatch): ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + model_assembly = create_plugin_model_assembly(tenant_id="test") + model_provider_factory = model_assembly.model_provider_factory provider_instance = model_provider_factory.get_model_provider("openai") - model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) + model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( @@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat http_client = mock.MagicMock() node = LLMNode( node_id="1", - config=llm_node_data, + data=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 892f6cc586..dd57dde1fe 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -28,7 +28,7 @@ def _build_template_transform_node( ) return TemplateTransformNode( node_id=node_id, - config=typed_node_data, + data=typed_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, **kwargs, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index a846efbb43..c25ac7da0f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -39,7 +39,7 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( node_id="test_node", - config=TemplateTransformNodeData( + data=TemplateTransformNodeData( title="Template Transform", type="template-transform", variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 364408ead6..a05151f79b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=0.0, ) return init_params, runtime_state @@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization(): node = _SampleNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=0.0, ) node = _SampleNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields(): node = _SampleNode( node_id="node-1", - config=node_config["data"], + data=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index dd75b32593..4c67f3fb02 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params): http_client = Mock() node = DocumentExtractorNode( node_id="test_node_id", - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -186,12 +186,13 @@ def test_run_extract_text( monkeypatch.setattr("graphon.file.file_manager.download", mock_download) + dispatch_mock = None if mime_type == "application/pdf": - mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + dispatch_mock = Mock(return_value=expected_text[0]) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock) elif mime_type.startswith("application/vnd.openxmlformats"): - mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + dispatch_mock = Mock(return_value=expected_text[0]) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock) result = document_extractor_node._run() @@ -200,6 +201,19 @@ def test_run_extract_text( assert result.outputs is not None assert result.outputs["text"] == ArrayStringSegment(value=expected_text) + if mime_type == "application/pdf": + dispatch_mock.assert_called_once_with( + file_content=file_content, + file_extension=extension, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + elif mime_type.startswith("application/vnd.openxmlformats"): + dispatch_mock.assert_called_once_with( + file_content=file_content, + mime_type=mime_type, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + if transfer_method == FileTransferMethod.REMOTE_URL: document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: @@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext file.extension = extension file.mime_type = mime_type - with ( - patch( - "graphon.nodes.document_extractor.node._download_file_content", - return_value=b"excel", - ), - patch( - "graphon.nodes.document_extractor.node._extract_text_from_excel", - return_value="excel text", - ) as mock_extract, + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", ): - result = _extract_text_from_file( - document_extractor_node.http_client, - file, - unstructured_api_config=document_extractor_node._unstructured_api_config, - ) + if extension: + with patch( + "graphon.nodes.document_extractor.node._extract_text_by_file_extension", + return_value="excel text", + ) as mock_extract: + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + mock_extract.assert_called_once_with( + file_content=b"excel", + file_extension=extension, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + else: + with patch( + "graphon.nodes.document_extractor.node._extract_text_by_mime_type", + return_value="excel text", + ) as mock_extract: + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + mock_extract.assert_called_once_with( + file_content=b"excel", + mime_type=mime_type, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) assert result == "excel text" - mock_extract.assert_called_once_with(b"excel") def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index aa9a1360b0..5965645c4f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -29,7 +29,7 @@ def _build_if_else_node( node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), ) @@ -48,7 +48,10 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="aaa", files=[]), + user_inputs={}, + ) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -148,7 +151,7 @@ def test_execute_if_else_result_false(): ) # construct variable pool - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) @@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions(): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) @@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure(): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 465a4c0ff4..1b4cecc757 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: return ListOperatorNode( node_id="test_node_id", - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=MagicMock(), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 5655f80737..f890f79511 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables): return StartNode( node_id="start", - config=node_data, + data=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot(): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( node_id="start", - config=node_data, + data=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 284af68319..4aa5803ac7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) + variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] @@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode: node = ToolNode( node_id="node-instance", - config=ToolNodeData.model_validate(config["data"]), + data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index e3b5e3b591..c5ac8d2ce2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 07d03bec05..fccb5ab1c3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -52,7 +52,7 @@ def create_webhook_node( node = TriggerWebhookNode( node_id="webhook-node-1", - config=webhook_data, + data=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index b839490d3c..c5ae542d8b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) ) node = TriggerWebhookNode( node_id="1", - config=webhook_data, + data=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 041c5cc612..a94f2a807a 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -55,7 +55,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_to_variable_pool_with_system_variables(self): """Test mapping system variables from user inputs to variable pool.""" # Initialize variable pool with system variables - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="test_user_id", app_id="test_app_id", @@ -128,7 +128,7 @@ class TestWorkflowEntry: return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() - variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={}) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}) expected_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, @@ -157,7 +157,7 @@ class TestWorkflowEntry: """Test mapping environment variables from user inputs to variable pool.""" # Initialize variable pool with environment variables env_var = StringVariable(name="API_KEY", value="existing_key") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), environment_variables=[env_var], user_inputs={}, @@ -198,7 +198,7 @@ class TestWorkflowEntry: """Test mapping conversation variables from user inputs to variable pool.""" # Initialize variable pool with conversation variables conv_var = StringVariable(name="last_message", value="Hello") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), conversation_variables=[conv_var], user_inputs={}, @@ -239,7 +239,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self): """Test mapping regular node variables from user inputs to variable pool.""" # Initialize empty variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -281,7 +281,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_file_handling(self): """Test mapping file inputs from user inputs to variable pool.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -340,7 +340,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_missing_variable_error(self): """Test that mapping raises error when required variable is missing.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -366,7 +366,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_alternative_key_format(self): """Test mapping with alternative key format (without node prefix).""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -396,7 +396,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_complex_selectors(self): """Test mapping with complex node variable keys.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -432,7 +432,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_invalid_node_variable(self): """Test that mapping handles invalid node variable format.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -463,7 +463,7 @@ class TestWorkflowEntry: env_var = StringVariable(name="API_KEY", value="existing_key") conv_var = StringVariable(name="session_id", value="session123") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="test_user", app_id="test_app", diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 270d0bf90d..c830255926 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -165,7 +165,7 @@ class TestWorkflowChildEngineBuilder: _ = graph_config node = node_cls( node_id=root_node_id, - config=BaseNodeData( + data=BaseNodeData( type=node_cls.node_type, title="Child Model", ), @@ -334,7 +334,7 @@ class TestWorkflowEntrySingleStepRun: def extract_variable_selector_to_variable_mapping(**_kwargs): return {} - variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={}) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}) variable_loader = MagicMock() variable_loader.load_variables.return_value = [ StringVariable( diff --git a/api/uv.lock b/api/uv.lock index 6f75c9f6fe..42c4e05d26 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1594,7 +1594,7 @@ requires-dist = [ { name = "gmpy2", specifier = ">=2.3.0" }, { name = "google-api-python-client", specifier = ">=2.195.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.149.0,<2.0.0" }, - { name = "graphon", specifier = "~=0.2.2" }, + { name = "graphon", specifier = "~=0.3.0" }, { name = "gunicorn", specifier = ">=25.3.0" }, { name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, @@ -2937,7 +2937,7 @@ httpx = [ [[package]] name = "graphon" -version = "0.2.2" +version = "0.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "charset-normalizer" }, @@ -2958,9 +2958,9 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "webvtt-py" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/08/50/e745a79c5f742f88f6011a1f7c9ba2c2f9cc1beedd982f0b192f1ab8c748/graphon-0.2.2.tar.gz", hash = "sha256:141f0de536171850f1af6f738dc66f0285aadd3c097f1dad2a038636789e0aa5", size = 236360, upload-time = "2026-04-17T08:52:28.047Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/62/83593d6e7a139ff124711ea05882cadca7065c11a38763aa9360d7e76804/graphon-0.3.0.tar.gz", hash = "sha256:cd38f842ae3dcfa956428b952efbe2a3ea9c1581446647142accbbdeb638b876", size = 241176, upload-time = "2026-04-21T15:18:48.291Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/89/a6340afdaf5169d17a318e00fc685fb67ed99baa602c2cbbbf6af6a76096/graphon-0.2.2-py3-none-any.whl", hash = "sha256:754e544d08779138f99eac6547ab08559463680e2c76488b05e1c978210392b4", size = 340808, upload-time = "2026-04-17T08:52:26.5Z" }, + { url = "https://files.pythonhosted.org/packages/b3/f7/81ee8f0368aa6a2d47f97fecc5d4a12865c987906798cbddd0e3b8387f33/graphon-0.3.0-py3-none-any.whl", hash = "sha256:9cca45ebab2a79fd4d04432f55b5b962e9e4f34fa037cc20fee7f18ec80eaa5d", size = 348486, upload-time = "2026-04-21T15:18:46.737Z" }, ] [[package]]