fix(api): resolve graphon upgrade CI failures

Repair the plugin model runtime for Graphon 0.3.0 by implementing structured-output support, normalizing non-stream LLM calls, and tightening workflow-layer type safety.

Update the remaining workflow tests to use VariablePool.from_bootstrap(...) and the new node data constructor API so the CI unit and integration suites match the upgraded runtime behavior.
-LAN- 2026-04-22 00:22:38 +08:00
parent 05aba26091
commit f13faa69b3
16 changed files with 222 additions and 69 deletions
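The test-side change the message describes is mechanical. A hedged before/after sketch, assuming `VariablePool.from_bootstrap` accepts the same keyword arguments the old constructor did (which is exactly what the test hunks below show); `VariablePool` and `build_system_variables` are the project's own helpers, so this snippet only runs inside the repo:

```python
# Before (Graphon < 0.3.0): bootstrap kwargs went straight to __init__.
variable_pool = VariablePool(
    system_variables=build_system_variables(user_id="aaa", files=[]),
    user_inputs={},
    environment_variables=[],
)

# After (Graphon 0.3.0): the classmethod owns bootstrap loading.
variable_pool = VariablePool.from_bootstrap(
    system_variables=build_system_variables(user_id="aaa", files=[]),
    user_inputs={},
    environment_variables=[],
)
```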

View File

@@ -5,7 +5,7 @@ This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final, override
from typing import final, override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
@@ -17,11 +17,6 @@ from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
if TYPE_CHECKING:
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@@ -109,16 +104,14 @@ class LLMQuotaLayer(GraphEngineLayer):
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
match node.node_type:
case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER:
pass
case _:
return None
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
model_instance = cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
model_instance = getattr(node, "model_instance", None)
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
@@ -133,4 +126,12 @@ class LLMQuotaLayer(GraphEngineLayer):
if isinstance(raw_model_instance, ModelInstance):
return raw_model_instance
private_model_instance = getattr(node, "_model_instance", None)
if isinstance(private_model_instance, ModelInstance):
return private_model_instance
wrapped_private_model_instance = getattr(private_model_instance, "_model_instance", None)
if isinstance(wrapped_private_model_instance, ModelInstance):
return wrapped_private_model_instance
return None
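The rewrite above drops the per-type `cast("LLMNode", ...)` lookups, and their `TYPE_CHECKING` imports, in favor of duck typing: any allowed node type that exposes `model_instance` qualifies, with private-attribute fallbacks for wrapped instances. A self-contained sketch of that fallback chain, with a stub standing in for the real `ModelInstance`:

```python
from typing import Any


class ModelInstance:
    """Stand-in for graphon's ModelInstance; illustrative only."""


def extract_model_instance(node: Any) -> ModelInstance | None:
    # Prefer the public attribute that LLM-flavored nodes expose.
    candidate = getattr(node, "model_instance", None)
    if isinstance(candidate, ModelInstance):
        return candidate
    # Fall back to the private attribute some node implementations keep.
    private = getattr(node, "_model_instance", None)
    if isinstance(private, ModelInstance):
        return private
    # Some wrappers hold the real instance one level deeper.
    wrapped = getattr(private, "_model_instance", None)
    if isinstance(wrapped, ModelInstance):
        return wrapped
    return None


class _FakeNode:
    def __init__(self) -> None:
        self._model_instance = ModelInstance()


assert isinstance(extract_model_instance(_FakeNode()), ModelInstance)
assert extract_model_instance(object()) is None
```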

View File

@@ -4,22 +4,31 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Union
from typing import IO, Any, Literal, cast, overload
from pydantic import ValidationError
from redis import RedisError
from configs import dify_config
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
from graphon.model_runtime.protocols.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
@@ -29,6 +38,57 @@ logger = logging.getLogger(__name__)
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
class _PluginStructuredOutputModelInstance:
"""Reuse the shared structured-output helper without depending on `ModelInstance`."""
def __init__(
self,
*,
runtime: PluginModelRuntime,
provider: str,
model: str,
credentials: dict[str, Any],
) -> None:
self._runtime = runtime
self._provider = provider
self._model = model
self._credentials = credentials
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
callbacks: object | None = None,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
del callbacks
if stream:
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=True,
)
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=False,
)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
@@ -195,6 +255,34 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@@ -206,9 +294,9 @@ class PluginModelRuntime(ModelRuntime):
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
) -> LLMResult | Generator[LLMResultChunk, None, None]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_llm(
result = self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
@@ -221,6 +309,81 @@ class PluginModelRuntime(ModelRuntime):
stop=list(stop) if stop else None,
stream=stream,
)
if stream:
return result
return normalize_non_stream_runtime_result(
model=model,
prompt_messages=prompt_messages,
result=result,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
model_schema = self.get_model_schema(
provider=provider,
model_type=ModelType.LLM,
model=model,
credentials=credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {model}")
adapter = _PluginStructuredOutputModelInstance(
runtime=self,
provider=provider,
model=model,
credentials=credentials,
)
return invoke_llm_with_structured_output_helper(
provider=provider,
model_schema=model_schema,
model_instance=cast(Any, adapter),
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=None,
stop=list(stop) if stop else None,
stream=stream,
)
def get_llm_num_tokens(
self,

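Three things in this file do the heavy lifting: the `Literal`-typed `stream` overloads give callers a precise return type with no casts, non-stream results are post-processed by `normalize_non_stream_runtime_result` before being returned, and `_PluginStructuredOutputModelInstance` adapts the runtime to whatever minimal interface the shared structured-output helper expects. A runnable sketch of the overload half, with stand-in `Result`/`Chunk` classes rather than the real LLM entities:

```python
from collections.abc import Generator
from typing import Literal, overload


class Result:
    """Stand-in for LLMResult; illustrative only."""


class Chunk:
    """Stand-in for LLMResultChunk; illustrative only."""


class Runtime:
    @overload
    def invoke(self, *, stream: Literal[False]) -> Result: ...

    @overload
    def invoke(self, *, stream: Literal[True]) -> Generator[Chunk, None, None]: ...

    def invoke(self, *, stream: bool) -> Result | Generator[Chunk, None, None]:
        if stream:
            # Streaming: yield chunks lazily, exactly as received.
            return (Chunk() for _ in range(3))
        # Non-stream: this is where a helper like
        # normalize_non_stream_runtime_result would post-process the result.
        return Result()


runtime = Runtime()
full = runtime.invoke(stream=False)   # type checkers infer Result
chunks = runtime.invoke(stream=True)  # type checkers infer Generator[Chunk, None, None]
assert isinstance(full, Result)
assert all(isinstance(c, Chunk) for c in chunks)
```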
View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Protocol, cast
from typing import Any, Protocol
from uuid import uuid4
from graphon.enums import BuiltinNodeTypes
@@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
normalized = _normalize_system_variable_values(values, **kwargs)
return [
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
),
segment_to_variable(
segment=build_segment(value),
selector=system_variable_selector(key),
name=key,
)
for key, value in normalized.items()
]
@@ -130,13 +127,10 @@ def build_bootstrap_variables(
for node_id, value in rag_pipeline_variables_map.items():
variables.append(
cast(
Variable,
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
),
segment_to_variable(
segment=build_segment(value),
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
name=node_id,
)
)
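Both hunks in this file delete a `cast(Variable, ...)` wrapper, which implies `segment_to_variable` is now annotated to return `Variable` directly (consistent with `cast` dropping out of the typing imports). A toy sketch of why tightening a factory's return annotation deletes call-site casts; the names here are illustrative:

```python
from typing import cast


class Segment:
    """Stand-in base type."""


class Variable(Segment):
    """Stand-in subtype that callers actually need."""


# Before: a loose return annotation forces every caller to cast.
def to_variable_loose(segment: Segment) -> Segment:
    return Variable()


v1 = cast(Variable, to_variable_loose(Segment()))


# After: the precise annotation makes the result usable as-is.
def to_variable(segment: Segment) -> Variable:
    return Variable()


v2: Variable = to_variable(Segment())
```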

View File

@@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader):
# This approach reduces loading time by querying external systems concurrently.
with ThreadPoolExecutor(max_workers=10) as executor:
offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
for selector, variable in offloaded_variables:
variable_by_selector[selector] = variable
for selector, offloaded_variable in offloaded_variables:
variable_by_selector[selector] = offloaded_variable
return list(variable_by_selector.values())
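The loop variable is renamed so it no longer shadows another `variable` binding, presumably to satisfy the stricter type checking this commit introduces; the surrounding pattern is an ordered fan-out, since `executor.map` yields results in submission order even though the calls run concurrently. A runnable sketch of the same shape, with a stand-in for `_load_offloaded_variable`:

```python
from concurrent.futures import ThreadPoolExecutor


def load_offloaded(variable_id: str) -> tuple[str, str]:
    """Stand-in for _load_offloaded_variable: fetch one variable remotely."""
    return variable_id, f"value-of-{variable_id}"


variable_by_selector: dict[str, str] = {}
with ThreadPoolExecutor(max_workers=10) as executor:
    # map() fans calls out across threads but yields in submission order.
    for selector, offloaded_variable in executor.map(load_offloaded, ["a", "b", "c"]):
        variable_by_selector[selector] = offloaded_variable

assert list(variable_by_selector) == ["a", "b", "c"]
```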

View File

@@ -45,7 +45,7 @@ def init_code_node(code_config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],

View File

@@ -55,7 +55,7 @@ def init_http_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
@@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
from graphon.runtime import VariablePool
# Create variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test", files=[]),
user_inputs={},
environment_variables=[],
@@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock):
)
# Create independent variable pool for this test only
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],

View File

@@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode:
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa",
app_id=app_id,

View File

@@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
),

View File

@@ -66,7 +66,7 @@ def test_execute_template_transform():
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],

View File

@@ -41,7 +41,7 @@ def init_tool_node(config: dict):
)
# construct variable pool
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],

View File

@@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(
workflow_execution_id=workflow_execution_id,
app_id=app_id,

View File

@@ -16,7 +16,7 @@ from models.workflow import Workflow
def _make_graph_state():
variable_pool = VariablePool(
variable_pool = VariablePool.from_bootstrap(
system_variables=default_system_variables(),
user_inputs={},
environment_variables=[],

View File

@@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs):
self.id = node_id
self.config = config
self.data = data
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self.kwargs = kwargs
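`DummyNode` tracks the constructor rename from `config=` to `data=`. A quick usage sketch against the test double above; the placeholder objects stand in for the real graph parameters:

```python
init_params = object()    # stands in for GraphInitParams
runtime_state = object()  # stands in for GraphRuntimeState

node = DummyNode(
    node_id="node-1",
    data={"type": "start", "version": "1"},
    graph_init_params=init_params,
    graph_runtime_state=runtime_state,
)
assert node.data["type"] == "start"
assert node.kwargs == {}
```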

View File

@@ -14,16 +14,11 @@ from graphon.nodes.llm.entities import LLMNodeData
from graphon.variables.segments import StringSegment
def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
def _assert_typed_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None:
_ = node_id
if isinstance(config, BaseNodeData):
assert config.type == node_type
assert config.version == version
return
assert isinstance(config, dict)
assert config["type"] == node_type
assert config["version"] == version
assert isinstance(data, BaseNodeData)
assert data.type == node_type
assert data.version == version
def _node_constructor(*, return_value):
@@ -470,7 +465,7 @@ class TestDifyNodeFactoryCreateNode:
matched_node_class.assert_called_once()
kwargs = matched_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
latest_node_class.assert_not_called()
@@ -490,7 +485,7 @@ class TestDifyNodeFactoryCreateNode:
latest_node_class.assert_called_once()
kwargs = latest_node_class.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
_assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9")
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@@ -528,7 +523,7 @@ class TestDifyNodeFactoryCreateNode:
assert result is created_node
kwargs = constructor.call_args.kwargs
assert kwargs["node_id"] == "node-id"
_assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type)
_assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=node_type)
assert kwargs["graph_init_params"] is sentinel.graph_init_params
assert kwargs["graph_runtime_state"] is factory.graph_runtime_state
@@ -597,7 +592,7 @@ class TestDifyNodeFactoryCreateNode:
prepared_llm.assert_called_once_with(sentinel.model_instance)
assert kwargs["model_instance"] is wrapped_model_instance
def test_create_node_passes_alias_preserving_llm_config_to_constructor(self, monkeypatch, factory):
def test_create_node_passes_typed_llm_data_to_constructor(self, monkeypatch, factory):
created_node = object()
constructor = _node_constructor(return_value=created_node)
monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor))
@@ -625,10 +620,10 @@ class TestDifyNodeFactoryCreateNode:
factory.create_node(node_config)
config = constructor.call_args.kwargs["config"]
assert isinstance(config, dict)
assert config["structured_output_enabled"] is True
assert "structured_output_switch_on" not in config
data = constructor.call_args.kwargs["data"]
assert isinstance(data, LLMNodeData)
assert data.structured_output_enabled is True
assert data.structured_output_switch_on is True
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
@@ -707,7 +702,7 @@ class TestDifyNodeFactoryCreateNode:
constructor_kwargs = constructor.call_args.kwargs
assert constructor_kwargs["node_id"] == "node-id"
_assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type)
_assert_typed_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type)
assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params
assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state
assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider
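The LLM assertions above imply `LLMNodeData` still accepts the legacy `structured_output_switch_on` spelling while exposing both attribute names after validation. One hedged way to get that behavior in pydantic v2 (an illustration only, not the project's actual model):

```python
from pydantic import AliasChoices, BaseModel, Field


class LLMData(BaseModel):
    """Toy model mirroring the alias behavior asserted above."""

    structured_output_enabled: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            "structured_output_enabled", "structured_output_switch_on"
        ),
    )

    @property
    def structured_output_switch_on(self) -> bool:
        # Keep the legacy name readable after validation.
        return self.structured_output_enabled


data = LLMData.model_validate({"structured_output_switch_on": True})
assert data.structured_output_enabled is True
assert data.structured_output_switch_on is True
```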

View File

@@ -109,8 +109,8 @@ class TestVariablePool:
assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None
assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None
def test_constructor_loads_legacy_bootstrap_kwargs(self):
pool = VariablePool(
def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self):
pool = VariablePool.from_bootstrap(
system_variables=build_system_variables(user_id="test_user_id"),
environment_variables=[StringVariable(name="env_var", value="env-value")],
conversation_variables=[StringVariable(name="conv_var", value="conv-value")],

View File

@@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution:
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
mock_node_cls.assert_called_once_with(
node_id="n-1",
config=sentinel.node_data,
data=sentinel.node_data,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,