chore(api): upgrade graphon to v0.3.0

Adapt the backend Graphon integration to the v0.3.0 breaking changes.

Migrate provider factory and runtime usage, switch workflow node construction to the new data payload API, and refresh backend tests for the updated VariablePool and node behaviors.
This commit is contained in:
-LAN- 2026-04-22 00:02:42 +08:00
parent 29f34848cd
commit 05aba26091
61 changed files with 470 additions and 285 deletions

View File

@ -23,7 +23,7 @@ from core.entities.provider_entities import (
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
"""Attach the already-composed runtime for request-bound call chains."""
self._bound_model_runtime = model_runtime
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
    """Return a (runtime, provider-factory) pair that stays consistent for this request.

    Prefers the request-bound runtime when one was attached via ``bind_model_runtime``;
    otherwise composes a fresh tenant-scoped plugin assembly and uses its pair.
    """
    bound_runtime = self._bound_model_runtime
    if bound_runtime is None:
        # No request-bound runtime: fall back to a tenant-scoped plugin assembly.
        assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
        return assembly.model_runtime, assembly.model_provider_factory
    # Keep the factory aligned with the runtime the caller already bound.
    return bound_runtime, ModelProviderFactory(runtime=bound_runtime)
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
_, model_provider_factory = self._get_runtime_and_provider_factory()
return model_provider_factory
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
model_provider_factory = self.get_model_provider_factory()
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
return create_model_type_instance(
runtime=model_runtime,
provider_schema=provider_schema,
model_type=model_type,
)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None

View File

@ -4,7 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = model_provider_factory.get_model_type_instance(
model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
model_type_instance = model_assembly.create_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)

View File

@ -20,7 +20,7 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelTy
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)

View File

@ -3,13 +3,46 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.protocols.runtime import ModelRuntime
if TYPE_CHECKING:
from core.model_manager import ModelManager
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
# Dispatch table mapping each supported ModelType to its graphon AIModel wrapper class.
# Used by create_model_type_instance(); a ModelType missing here is treated as unsupported.
_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
    ModelType.LLM: LargeLanguageModel,
    ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
    ModelType.RERANK: RerankModel,
    ModelType.SPEECH2TEXT: Speech2TextModel,
    ModelType.MODERATION: ModerationModel,
    ModelType.TTS: TTSModel,
}
def create_model_type_instance(
    *,
    runtime: ModelRuntime,
    provider_schema: ProviderEntity,
    model_type: ModelType,
) -> AIModel:
    """Build the graphon model wrapper explicitly against the request runtime.

    Resolves the concrete AIModel subclass from ``_MODEL_CLASS_BY_TYPE`` and
    instantiates it with the given provider schema and runtime.

    Raises:
        ValueError: if ``model_type`` has no registered wrapper class.
    """
    if model_type not in _MODEL_CLASS_BY_TYPE:
        raise ValueError(f"Unsupported model type: {model_type}")
    wrapper_class = _MODEL_CLASS_BY_TYPE[model_type]
    return wrapper_class(provider_schema=provider_schema, model_runtime=runtime)
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
@ -38,9 +71,22 @@ class PluginModelAssembly:
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
return self._model_provider_factory
def create_model_type_instance(
    self,
    *,
    provider: str,
    model_type: ModelType,
) -> AIModel:
    """Build a model wrapper for *provider*/*model_type* bound to this assembly's runtime."""
    # Resolve the provider schema through this assembly's factory so the schema
    # and the runtime passed below always come from the same plugin runtime.
    schema = self.model_provider_factory.get_provider_schema(provider=provider)
    return create_model_type_instance(
        runtime=self.model_runtime,
        provider_schema=schema,
        model_type=model_type,
    )
@property
def provider_manager(self) -> ProviderManager:
if self._provider_manager is None:
@ -79,6 +125,20 @@ def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory
def create_plugin_model_type_instance(
    *,
    tenant_id: str,
    provider: str,
    model_type: ModelType,
    user_id: str | None = None,
) -> AIModel:
    """Create a tenant-bound model wrapper for the requested provider and model type."""
    # Compose a tenant-scoped assembly first, then delegate wrapper construction to it.
    assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)
    return assembly.create_model_type_instance(provider=provider, model_type=model_type)
def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
    """Create a tenant-bound provider manager for service flows."""
    assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)
    return assembly.provider_manager

View File

@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
@ -165,7 +165,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@ -362,7 +362,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
config_for_node_init: BaseNodeData | dict[str, Any]
if isinstance(resolved_node_data, BaseNodeData):
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
else:
config_for_node_init = resolved_node_data
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -448,7 +443,7 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
node_id=node_id,
config=config_for_node_init,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,

View File

@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
node_id: str,
config: AgentNodeData,
data: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
node_id: str,
config: DatasourceNodeData,
data: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
node_id: str,
config: KnowledgeIndexNodeData,
data: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
node_id: str,
config: KnowledgeRetrievalNodeData,
data: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

View File

@ -45,7 +45,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.2.2",
"graphon~=0.3.0",
"httpx-sse~=0.4.0",
"json-repair~=0.59.4",
]

View File

@ -1251,7 +1251,7 @@ class WorkflowService:
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
node_id=node_config["id"],
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),

View File

@ -71,7 +71,7 @@ def test_node_integration_minimal_stream(mocker):
node = DatasourceNode(
node_id="n",
config=DatasourceNodeData(
data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from graphon.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderType
@ -15,8 +15,9 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_provider_factory = model_assembly.model_provider_factory
model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",

View File

@ -66,7 +66,7 @@ def init_code_node(code_config: dict):
node = CodeNode(
node_id=str(uuid.uuid4()),
config=CodeNodeData.model_validate(code_config["data"]),
data=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,

View File

@ -76,7 +76,7 @@ def init_http_node(config: dict):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(config["data"]),
data=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock):
node = HttpRequestNode(
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode:
node = LLMNode(
node_id=str(uuid.uuid4()),
config=LLMNodeData.model_validate(config["data"]),
data=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
node = ParameterExtractorNode(
node_id=str(uuid.uuid4()),
config=ParameterExtractorNodeData.model_validate(config["data"]),
data=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -88,7 +88,7 @@ def test_execute_template_transform():
node = TemplateTransformNode(
node_id=str(uuid.uuid4()),
config=TemplateTransformNodeData.model_validate(config["data"]),
data=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),

View File

@ -62,7 +62,7 @@ def init_tool_node(config: dict):
node = ToolNode(
node_id=str(uuid.uuid4()),
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers:
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id))
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id=execution_id)
)
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)

View File

@ -102,7 +102,7 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
node_id="start",
config=start_data,
data=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@ -117,7 +117,7 @@ def _build_graph(
)
human_node = HumanInputNode(
node_id="human",
config=human_data,
data=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@ -131,7 +131,7 @@ def _build_graph(
)
end_node = EndNode(
node_id="end",
config=end_data,
data=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)

View File

@ -372,7 +372,9 @@ class TestAdvancedChatGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@ -583,7 +585,9 @@ class TestAdvancedChatGenerateTaskPipeline:
self.items = items
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@ -617,7 +621,9 @@ class TestAdvancedChatGenerateTaskPipeline:
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"

View File

@ -58,7 +58,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
def __init__(
self,
node_id: str,
config: _StubToolNodeData,
data: _StubToolNodeData,
*,
graph_init_params,
graph_runtime_state,
@ -66,7 +66,7 @@ class _StubToolNode(Node[_StubToolNodeData]):
) -> None:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@ -167,7 +167,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],

View File

@ -54,7 +54,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -93,7 +93,7 @@ class TestWorkflowBasedAppRunner:
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -162,7 +162,7 @@ class TestWorkflowBasedAppRunner:
app_id="app",
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
@ -241,7 +241,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
graph_runtime_state.register_paused_node("node-1")
@ -284,7 +284,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
@ -423,7 +423,7 @@ class TestWorkflowBasedAppRunner:
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
start_at=0.0,
)
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))

View File

@ -283,7 +283,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@ -725,7 +727,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
@ -753,7 +757,9 @@ class TestWorkflowGenerateTaskPipeline:
pipeline = _make_pipeline()
pipeline._workflow_execution_id = "run-id"
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
@ -769,7 +775,9 @@ class TestWorkflowGenerateTaskPipeline:
def test_process_stream_response_main_match_paths_and_cleanup(self):
pipeline = _make_pipeline()
pipeline._graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-id")
),
start_at=0.0,
)
pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(

View File

@ -21,7 +21,9 @@ class TestTriggerPostLayer:
)
runtime_state = SimpleNamespace(
outputs={"answer": "ok"},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=12,
)
@ -60,7 +62,9 @@ class TestTriggerPostLayer:
def test_on_event_handles_missing_trigger_log(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)
@ -91,7 +95,9 @@ class TestTriggerPostLayer:
def test_on_event_ignores_non_status_events(self):
runtime_state = SimpleNamespace(
outputs={},
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(workflow_execution_id="run-1")
),
total_tokens=0,
)

View File

@ -60,7 +60,10 @@ def _make_layer(
workflow_execution_id="run-id",
conversation_id="conv-id",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool.from_bootstrap(system_variables=system_variables),
start_at=0.0,
)
read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
application_generate_entity = WorkflowAppGenerateEntity.model_construct(

View File

@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
with patch(
@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None:
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None:
def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
configuration = _build_provider_configuration()
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
mock_assembly = Mock()
mock_assembly.model_runtime = Mock()
mock_assembly.model_provider_factory = mock_factory
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory",
return_value=mock_factory,
) as mock_factory_builder:
with (
patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=mock_assembly,
) as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_builder.call_count == 2
mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
assert mock_assembly_builder.call_count == 2
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=mock_assembly.model_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
mock_factory.get_model_schema.assert_called_once_with(
provider="openai",
model_type=ModelType.LLM,
@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
bound_runtime = Mock()
configuration.bind_model_runtime(bound_runtime)
mock_factory = Mock()
mock_model_type_instance = Mock()
mock_schema = _build_ai_model("gpt-4o")
mock_factory.get_model_type_instance.return_value = mock_model_type_instance
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = configuration.provider
mock_factory.get_model_schema.return_value = mock_schema
with (
patch(
"core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
) as mock_factory_cls,
patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder,
patch(
"core.entities.provider_configuration.create_model_type_instance",
return_value=mock_model_type_instance,
) as mock_model_builder,
):
model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non
assert model_type_instance is mock_model_type_instance
assert model_schema is mock_schema
assert mock_factory_cls.call_count == 2
mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
mock_factory_builder.assert_not_called()
mock_factory_cls.assert_called_with(runtime=bound_runtime)
mock_assembly_builder.assert_not_called()
mock_factory.get_provider_schema.assert_called_once_with(provider="openai")
mock_model_builder.assert_called_once_with(
runtime=bound_runtime,
provider_schema=configuration.provider,
model_type=ModelType.LLM,
)
def test_get_provider_model_returns_none_when_model_not_found() -> None:
@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
mock_factory = Mock()
mock_factory.get_provider_schema.return_value = provider_schema
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
validated = configuration.validate_provider_credentials(
@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
mock_factory2 = Mock()
mock_factory2.model_credentials_validate.return_value = {"region": "us"}
with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
with patch(
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2),
):
validated = configuration.validate_custom_model_credentials(
model_type=ModelType.LLM,
model="gpt-4o",
@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_provider_credentials(
@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
"core.entities.provider_configuration.create_plugin_model_assembly",
return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory),
):
with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
validated = configuration.validate_custom_model_credentials(

View File

@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk")
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture)
provider_map={openai_provider: hosting_openai},
),
)
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory")
factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly")
choice_mock = mocker.patch("core.helper.moderation.secrets.choice")
assert (
@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi
mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk")
moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
assert (
check_moderation(
@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo
failing_model = SimpleNamespace(
invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")),
)
factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory)
assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model)
mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly)
with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."):
check_moderation(

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock
import pytest
from core.plugin.impl.model_runtime_factory import create_model_type_instance
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None:
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
provider_schema = factory.get_model_provider("openai")
@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
provider_schema = factory.get_model_provider("openai")
@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro
def test_model_provider_factory_requires_runtime() -> None:
with pytest.raises(ValueError, match="model_runtime is required"):
ModelProviderFactory(model_runtime=None) # type: ignore[arg-type]
with pytest.raises(ValueError, match="runtime is required"):
ModelProviderFactory(runtime=None) # type: ignore[arg-type]
def test_model_provider_factory_get_providers_returns_runtime_providers() -> None:
@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non
supported_model_types=[ModelType.LLM],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
result = factory.get_providers()
@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider]))
factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider]))
result = factory.get_provider_schema("openai")
@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup
def test_model_provider_factory_raises_for_unknown_provider() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() ->
models=[_build_model("rerank-v3", ModelType.RERANK)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai", model_type=ModelType.LLM)
@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod
models=[_build_model("eleven_multilingual_v2", ModelType.TTS)],
),
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(model_type=ModelType.TTS)
@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider
models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)],
)
]
factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers))
factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers))
results = factory.get_models(provider="openai")
@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.provider_credentials_validate(
provider="openai",
@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None:
def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
)
]
)
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
filtered = factory.model_credentials_validate(
provider="openai",
@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None:
def test_model_provider_factory_model_credentials_validate_requires_schema() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
)
runtime.get_model_schema.return_value = "schema"
runtime.get_provider_icon.return_value = (b"icon", "image/png")
factory = ModelProviderFactory(model_runtime=runtime)
factory = ModelProviderFactory(runtime=runtime)
assert (
factory.get_model_schema(
@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider
(ModelType.TTS, TTSModel),
],
)
def test_model_provider_factory_builds_model_type_instances(
def test_create_model_type_instance_builds_model_wrappers(
model_type: ModelType,
expected_type: type[object],
) -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[model_type],
)
]
)
instance = factory.get_model_type_instance("openai", model_type)
instance = create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type=model_type,
)
assert isinstance(instance, expected_type)
def test_model_provider_factory_rejects_unsupported_model_type() -> None:
factory = ModelProviderFactory(
model_runtime=_FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
def test_create_model_type_instance_rejects_unsupported_model_type() -> None:
runtime = _FakeModelRuntime(
[
_build_provider(
provider="langgenius/openai/openai",
provider_name="openai",
supported_model_types=[ModelType.LLM],
)
]
)
with pytest.raises(ValueError, match="Unsupported model type: unsupported"):
factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type]
create_model_type_instance(
runtime=runtime,
provider_schema=runtime.fetch_model_providers()[0],
model_type="unsupported", # type: ignore[arg-type]
)

View File

@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views():
assert assembly.model_manager is model_manager
mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1")
mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime)
mock_provider_factory_cls.assert_called_once_with(runtime=runtime)
mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime)
mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager)

View File

@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
result = manager.get_default_model("tenant-id", ModelType.LLM)
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result is not None
assert result.model == "gpt-4"
assert result.provider.provider == "openai"
@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock
result = manager.get_configurations("tenant-id")
expected_alias = str(ModelProviderID("openai"))
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
assert result.tenant_id == "tenant-id"
assert expected_alias in provider_records
assert expected_alias in provider_model_records
@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF
assert first is second
mock_get_all_providers.assert_called_once_with("tenant-id")
mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime)
mock_provider_configuration.assert_called_once()
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)

View File

@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory):
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory):
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory):
}:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory):
else:
mock_instance = mock_class(
node_id=node_id,
config=resolved_node_data,
data=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,

View File

@ -56,7 +56,7 @@ class MockNodeMixin:
def __init__(
self,
node_id: str,
config: Any,
data: Any,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
@ -98,7 +98,7 @@ class MockNodeMixin:
super().__init__(
node_id=node_id,
config=config,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,

View File

@ -111,7 +111,7 @@ class StaticRepo(HumanInputFormRepository):
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -140,7 +140,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
node_id=start_config["id"],
config=StartNodeData(title="Start", variables=[]),
data=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@ -155,7 +155,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
node_id=human_a_config["id"],
config=human_data,
data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@ -165,7 +165,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
node_id=human_b_config["id"],
config=human_data,
data=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
node_id=end_config["id"],
config=end_data,
data=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -48,7 +48,7 @@ def test_execute_answer():
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -71,7 +71,7 @@ def test_execute_answer():
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=AnswerNodeData(
data=AnswerNodeData(
title="123",
type="answer",
answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",

View File

@ -79,7 +79,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
node = DatasourceNode(
node_id="n",
config=DatasourceNodeData(
data=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",

View File

@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -320,7 +320,7 @@ def test_init_headers():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@ -357,7 +357,7 @@ def test_init_params():
node_data=node_data,
timeout=timeout,
http_request_config=HTTP_REQUEST_CONFIG,
variable_pool=VariablePool(system_variables=default_system_variables()),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()),
http_client=ssrf_proxy,
file_manager=file_manager,
)
@ -390,7 +390,7 @@ def test_init_params():
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer():
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic():
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom():
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error():
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
variable_pool = VariablePool(system_variables=default_system_variables())
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables())
node_data = HttpRequestNodeData(
title="test",
method="get",
@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable():
# UUID that triggers the json_repair truncation bug
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
"""
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
def test_executor_with_json_body_preserves_numbers_and_strings():
"""Test that numbers are preserved and string values are properly quoted."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)

View File

@ -110,12 +110,15 @@ def _build_http_node(
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=time.perf_counter(),
)
return HttpRequestNode(
node_id="http-node",
config=HttpRequestNodeData.model_validate(node_data),
data=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@ -149,7 +149,7 @@ def _build_human_input_node(
)
return HumanInputNode(
node_id=node_id,
config=typed_node_data,
data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=runtime,
@ -241,16 +241,16 @@ class TestUserAction:
def test_user_action_length_boundaries(self):
"""Test user action id and title length boundaries."""
action = UserAction(id="a" * 20, title="b" * 20)
action = UserAction(id="a" * 20, title="b" * 100)
assert action.id == "a" * 20
assert action.title == "b" * 20
assert action.title == "b" * 100
@pytest.mark.parametrize(
("field_name", "value"),
[
("id", "a" * 21),
("title", "b" * 21),
("title", "b" * 101),
],
)
def test_user_action_length_limits(self, field_name: str, value: str):
@ -427,7 +427,7 @@ class TestHumanInputNodeVariableResolution:
"""Tests for resolving variable-based defaults in HumanInputNode."""
def test_resolves_variable_defaults(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -504,7 +504,7 @@ class TestHumanInputNodeVariableResolution:
assert params.resolved_default_values == expected_values
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -565,7 +565,7 @@ class TestHumanInputNodeVariableResolution:
assert not hasattr(pause_event.reason, "form_token")
def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",
@ -631,7 +631,7 @@ class TestHumanInputNodeVariableResolution:
assert params.display_in_ui is True
def test_debugger_debug_mode_overrides_email_recipients(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user-123",
app_id="app",
@ -748,7 +748,7 @@ class TestHumanInputNodeRenderedContent:
"""Tests for rendering submitted content."""
def test_replaces_outputs_placeholders_after_submission(self):
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="user",
app_id="app",

View File

@ -40,7 +40,7 @@ def _create_human_input_node(
)
return HumanInputNode(
node_id=config["id"],
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=repo,
@ -51,7 +51,11 @@ def _create_human_input_node(
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
variable_pool=VariablePool.from_bootstrap(
system_variables=system_variables,
user_inputs={},
environment_variables=[],
),
start_at=0.0,
)
graph_init_params = GraphInitParams(
@ -114,7 +118,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
def _build_timeout_node() -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
variable_pool=VariablePool.from_bootstrap(
system_variables=system_variables,
user_inputs={},
environment_variables=[],
),
start_at=0.0,
)
graph_init_params = GraphInitParams(

View File

@ -32,7 +32,7 @@ class _MissingGraphBuilder:
def _build_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}),
start_at=0.0,
)
@ -46,7 +46,7 @@ def _build_iteration_node(
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
node_id="iteration-node",
config=IterationNodeData(
data=IterationNodeData(
type="iteration",
title="Iteration",
iterator_selector=["start", "items"],

View File

@ -40,7 +40,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@ -102,7 +102,7 @@ def _build_node(
) -> KnowledgeIndexNode:
return KnowledgeIndexNode(
node_id=node_id,
config=(
data=(
node_data
if isinstance(node_data, KnowledgeIndexNodeData)
else KnowledgeIndexNodeData.model_validate(node_data)

View File

@ -46,7 +46,7 @@ def mock_graph_init_params():
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
@ -117,7 +117,7 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -146,7 +146,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -205,7 +205,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -249,7 +249,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -285,7 +285,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -320,7 +320,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -361,7 +361,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -400,7 +400,7 @@ class TestKnowledgeRetrievalNode:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -481,7 +481,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -518,7 +518,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -573,7 +573,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -621,7 +621,7 @@ class TestFetchDatasetRetriever:
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -682,7 +682,7 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
node_id=node_id,
config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
data=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@ -16,10 +16,10 @@ class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
@staticmethod
def _build_node(*, config, graph_init_params, graph_runtime_state):
def _build_node(*, data, graph_init_params, graph_runtime_state):
return ListOperatorNode(
node_id="test",
config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@ -65,7 +65,7 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
return self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -83,7 +83,7 @@ class TestListOperatorNode:
}
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -127,7 +127,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -153,7 +153,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -177,7 +177,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -201,7 +201,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -228,7 +228,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -255,7 +255,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -282,7 +282,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -312,7 +312,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -335,7 +335,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -359,7 +359,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -384,7 +384,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -408,7 +408,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -432,7 +432,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -456,7 +456,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@ -483,7 +483,7 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
node = self._build_node(
config=config,
data=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@ -15,7 +15,7 @@ from core.app.llm.model_access import (
)
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -208,7 +208,7 @@ def llm_node(
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
config=llm_node_data,
data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@ -241,9 +241,10 @@ def model_config(monkeypatch):
)
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test"))
model_assembly = create_plugin_model_assembly(tenant_id="test")
model_provider_factory = model_assembly.model_provider_factory
provider_instance = model_provider_factory.get_model_provider("openai")
model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM)
model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM)
# Create a ProviderModelBundle
provider_model_bundle = ProviderModelBundle(
@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
http_client = mock.MagicMock()
node = LLMNode(
node_id="1",
config=llm_node_data,
data=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,

View File

@ -28,7 +28,7 @@ def _build_template_transform_node(
)
return TemplateTransformNode(
node_id=node_id,
config=typed_node_data,
data=typed_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
**kwargs,

View File

@ -39,7 +39,7 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
node_id="test_node",
config=TemplateTransformNodeData(
data=TemplateTransformNodeData(
title="Template Transform",
type="template-transform",
variables=[],

View File

@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
invoke_from="debugger",
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
return init_params, runtime_state
@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization():
node = _SampleNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum():
invoke_from=InvokeFrom.DEBUGGER,
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
variable_pool=VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="user", files=[]),
user_inputs={},
),
start_at=0.0,
)
node = _SampleNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields():
node = _SampleNode(
node_id="node-1",
config=node_config["data"],
data=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params):
http_client = Mock()
node = DocumentExtractorNode(
node_id="test_node_id",
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@ -186,12 +186,13 @@ def test_run_extract_text(
monkeypatch.setattr("graphon.file.file_manager.download", mock_download)
dispatch_mock = None
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
dispatch_mock = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract)
dispatch_mock = Mock(return_value=expected_text[0])
monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock)
result = document_extractor_node._run()
@ -200,6 +201,19 @@ def test_run_extract_text(
assert result.outputs is not None
assert result.outputs["text"] == ArrayStringSegment(value=expected_text)
if mime_type == "application/pdf":
dispatch_mock.assert_called_once_with(
file_content=file_content,
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
elif mime_type.startswith("application/vnd.openxmlformats"):
dispatch_mock.assert_called_once_with(
file_content=file_content,
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
if transfer_method == FileTransferMethod.REMOTE_URL:
document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext
file.extension = extension
file.mime_type = mime_type
with (
patch(
"graphon.nodes.document_extractor.node._download_file_content",
return_value=b"excel",
),
patch(
"graphon.nodes.document_extractor.node._extract_text_from_excel",
return_value="excel text",
) as mock_extract,
with patch(
"graphon.nodes.document_extractor.node._download_file_content",
return_value=b"excel",
):
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
if extension:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_file_extension",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
file_extension=extension,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
else:
with patch(
"graphon.nodes.document_extractor.node._extract_text_by_mime_type",
return_value="excel text",
) as mock_extract:
result = _extract_text_from_file(
document_extractor_node.http_client,
file,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
mock_extract.assert_called_once_with(
file_content=b"excel",
mime_type=mime_type,
unstructured_api_config=document_extractor_node._unstructured_api_config,
)
assert result == "excel text"
mock_extract.assert_called_once_with(b"excel")
def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):

View File

@ -29,7 +29,7 @@ def _build_if_else_node(
node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
)
@ -48,7 +48,10 @@ def test_execute_if_else_result_true():
)
# construct variable pool
pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@ -148,7 +151,7 @@ def test_execute_if_else_result_false():
)
# construct variable pool
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions():
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)
@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure():
)
# construct variable pool with boolean values
pool = VariablePool(
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(files=[], user_id="aaa"),
)
pool.add(["start", "bool_true"], True)

View File

@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment
def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
return ListOperatorNode(
node_id="test_node_id",
config=node_data,
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=MagicMock(),
)

View File

@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables):
return StartNode(
node_id="start",
config=node_data,
data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot():
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
node_id="start",
config=node_data,
data=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},

View File

@ -99,7 +99,7 @@ def tool_node(monkeypatch) -> ToolNode:
call_depth=0,
)
variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
variable_pool = VariablePool.from_bootstrap(system_variables=build_system_variables(user_id="user-id"))
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
@ -110,7 +110,7 @@ def tool_node(monkeypatch) -> ToolNode:
node = ToolNode(
node_id="node-instance",
config=ToolNodeData.model_validate(config["data"]),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
node_id="node-1",
config=_build_node_data(),
data=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)

View File

@ -52,7 +52,7 @@ def create_webhook_node(
node = TriggerWebhookNode(
node_id="webhook-node-1",
config=webhook_data,
data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
)
node = TriggerWebhookNode(
node_id="1",
config=webhook_data,
data=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)

View File

@ -55,7 +55,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_system_variables(self):
"""Test mapping system variables from user inputs to variable pool."""
# Initialize variable pool with system variables
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user_id",
app_id="test_app_id",
@ -128,7 +128,7 @@ class TestWorkflowEntry:
return NodeConfigDictAdapter.validate_python(node_config)
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
@ -157,7 +157,7 @@ class TestWorkflowEntry:
"""Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables
env_var = StringVariable(name="API_KEY", value="existing_key")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
environment_variables=[env_var],
user_inputs={},
@ -198,7 +198,7 @@ class TestWorkflowEntry:
"""Test mapping conversation variables from user inputs to variable pool."""
# Initialize variable pool with conversation variables
conv_var = StringVariable(name="last_message", value="Hello")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
conversation_variables=[conv_var],
user_inputs={},
@ -239,7 +239,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self):
"""Test mapping regular node variables from user inputs to variable pool."""
# Initialize empty variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -281,7 +281,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_file_handling(self):
"""Test mapping file inputs from user inputs to variable pool."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -340,7 +340,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_missing_variable_error(self):
"""Test that mapping raises error when required variable is missing."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -366,7 +366,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_alternative_key_format(self):
"""Test mapping with alternative key format (without node prefix)."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -396,7 +396,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_with_complex_selectors(self):
"""Test mapping with complex node variable keys."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -432,7 +432,7 @@ class TestWorkflowEntry:
def test_mapping_user_inputs_invalid_node_variable(self):
"""Test that mapping handles invalid node variable format."""
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
)
@ -463,7 +463,7 @@ class TestWorkflowEntry:
env_var = StringVariable(name="API_KEY", value="existing_key")
conv_var = StringVariable(name="session_id", value="session123")
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="test_user",
app_id="test_app",

View File

@ -165,7 +165,7 @@ class TestWorkflowChildEngineBuilder:
_ = graph_config
node = node_cls(
node_id=root_node_id,
config=BaseNodeData(
data=BaseNodeData(
type=node_cls.node_type,
title="Child Model",
),
@ -334,7 +334,7 @@ class TestWorkflowEntrySingleStepRun:
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={})
variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={})
variable_loader = MagicMock()
variable_loader.load_variables.return_value = [
StringVariable(

8
api/uv.lock generated
View File

@ -1594,7 +1594,7 @@ requires-dist = [
{ name = "gmpy2", specifier = ">=2.3.0" },
{ name = "google-api-python-client", specifier = ">=2.195.0" },
{ name = "google-cloud-aiplatform", specifier = ">=1.149.0,<2.0.0" },
{ name = "graphon", specifier = "~=0.2.2" },
{ name = "graphon", specifier = "~=0.3.0" },
{ name = "gunicorn", specifier = ">=25.3.0" },
{ name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" },
{ name = "httpx-sse", specifier = "~=0.4.0" },
@ -2937,7 +2937,7 @@ httpx = [
[[package]]
name = "graphon"
version = "0.2.2"
version = "0.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "charset-normalizer" },
@ -2958,9 +2958,9 @@ dependencies = [
{ name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] },
{ name = "webvtt-py" },
]
sdist = { url = "https://files.pythonhosted.org/packages/08/50/e745a79c5f742f88f6011a1f7c9ba2c2f9cc1beedd982f0b192f1ab8c748/graphon-0.2.2.tar.gz", hash = "sha256:141f0de536171850f1af6f738dc66f0285aadd3c097f1dad2a038636789e0aa5", size = 236360, upload-time = "2026-04-17T08:52:28.047Z" }
sdist = { url = "https://files.pythonhosted.org/packages/bf/62/83593d6e7a139ff124711ea05882cadca7065c11a38763aa9360d7e76804/graphon-0.3.0.tar.gz", hash = "sha256:cd38f842ae3dcfa956428b952efbe2a3ea9c1581446647142accbbdeb638b876", size = 241176, upload-time = "2026-04-21T15:18:48.291Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/de/89/a6340afdaf5169d17a318e00fc685fb67ed99baa602c2cbbbf6af6a76096/graphon-0.2.2-py3-none-any.whl", hash = "sha256:754e544d08779138f99eac6547ab08559463680e2c76488b05e1c978210392b4", size = 340808, upload-time = "2026-04-17T08:52:26.5Z" },
{ url = "https://files.pythonhosted.org/packages/b3/f7/81ee8f0368aa6a2d47f97fecc5d4a12865c987906798cbddd0e3b8387f33/graphon-0.3.0-py3-none-any.whl", hash = "sha256:9cca45ebab2a79fd4d04432f55b5b962e9e4f34fa037cc20fee7f18ec80eaa5d", size = 348486, upload-time = "2026-04-21T15:18:46.737Z" },
]
[[package]]