feat(api): LLM polling support (#37462)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
QuantumGhost 2026-06-18 07:34:33 +08:00 committed by GitHub
parent 19838972dc
commit f0b34bdeb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 704 additions and 46 deletions

View File

@ -20,6 +20,7 @@ from core.plugin.impl.exc import (
PluginDaemonNotFoundError,
PluginDaemonUnauthorizedError,
PluginInvokeError,
PluginLLMPollingUnsupportedError,
PluginNotFoundError,
PluginPermissionDeniedError,
PluginUniqueIdentifierError,
@ -370,6 +371,10 @@ class BasePluginClient:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("message"))
# NOTE: the current plugin SDK / plugin daemon does not raise an exception of
# type `PluginLLMPollingUnsupportedError`.
case PluginLLMPollingUnsupportedError.__name__:
raise PluginLLMPollingUnsupportedError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:

View File

@ -5,6 +5,13 @@ from pydantic import TypeAdapter
from extensions.ext_logging import get_request_id
# NOTE: Avoid renaming exception classes in this file, since
# `_handle_plugin_daemon_error` in api/core/plugin/impl/base.py
# builds exception instances based on the class name.
#
# Renaming an exception class could result in an incorrect exception
# being raised.
class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors."""
@ -75,6 +82,10 @@ class PluginInvokeError(PluginDaemonClientSideError, ValueError):
)
class PluginLLMPollingUnsupportedError(PluginInvokeError):
    """Raised when the requested model cannot serve plugin-backed LLM polling.

    It derives from ``PluginInvokeError``, so handlers written for invoke
    failures also catch this error.
    """
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
description: str = "Unique Identifier Error"

View File

@ -13,13 +13,17 @@ from core.plugin.entities.plugin_daemon import (
PluginVoicesResponse,
)
from core.plugin.impl.base import BasePluginClient
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
from core.plugin.impl.exc import PluginInvokeError, PluginLLMPollingUnsupportedError
from graphon.model_runtime.entities.llm_entities import LLMPollingResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.utils.encoders import jsonable_encoder
_POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES = frozenset((NotImplementedError.__name__,))
_POLLING_UNSUPPORTED_ERROR_MESSAGE = "does not support polling"
class PluginModelClient(BasePluginClient):
@staticmethod
@ -197,6 +201,103 @@ class PluginModelClient(BasePluginClient):
except PluginDaemonInnerError as e:
raise ValueError(e.message + str(e.code))
def start_llm_polling(
    self,
    tenant_id: str,
    user_id: str | None,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict[str, Any],
    prompt_messages: list[PromptMessage],
    model_parameters: dict[str, Any] | None = None,
    tools: list[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    json_schema: dict[str, Any] | None = None,
) -> LLMPollingResult:
    """Submit a polling-based LLM invocation to the plugin daemon.

    The payload always targets the LLM model type and is always sent with
    ``stream`` set to ``False``.

    Returns:
        The ``LLMPollingResult`` parsed from the daemon response.

    Raises:
        PluginLLMPollingUnsupportedError: when the daemon reports an invoke
            error that identifies missing polling support.
        PluginInvokeError: any other invoke error, re-raised unchanged.
    """
    try:
        invocation = {
            "provider": provider,
            "model_type": ModelType.LLM.value,
            "model": model,
            "credentials": credentials,
            "prompt_messages": prompt_messages,
            "model_parameters": model_parameters,
            "tools": tools,
            "stop": stop,
            "stream": False,
            "json_schema": json_schema,
        }
        request_body = jsonable_encoder(self._dispatch_payload(user_id=user_id, data=invocation))
        request_headers = {
            "X-Plugin-ID": plugin_id,
            "Content-Type": "application/json",
        }
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/model/polling/start",
            type_=LLMPollingResult,
            data=request_body,
            headers=request_headers,
        )
    except PluginInvokeError as error:
        # Raises the dedicated error when applicable; otherwise falls through
        # so the original invoke error is re-raised as-is.
        self._raise_typed_polling_unsupported_error(error)
        raise
def check_llm_polling(
    self,
    tenant_id: str,
    user_id: str | None,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict[str, Any],
    plugin_state: dict[str, Any],
) -> LLMPollingResult:
    """Query the plugin daemon for the current state of an LLM polling job.

    ``plugin_state`` is sent along in the request payload.
    NOTE(review): presumably this is the state returned by an earlier
    polling result - confirm against the caller.

    Returns:
        The ``LLMPollingResult`` parsed from the daemon response.

    Raises:
        PluginLLMPollingUnsupportedError: when the daemon reports an invoke
            error that identifies missing polling support.
        PluginInvokeError: any other invoke error, re-raised unchanged.
    """
    try:
        invocation = {
            "provider": provider,
            "model_type": ModelType.LLM.value,
            "model": model,
            "credentials": credentials,
            "plugin_state": plugin_state,
        }
        request_body = jsonable_encoder(self._dispatch_payload(user_id=user_id, data=invocation))
        request_headers = {
            "X-Plugin-ID": plugin_id,
            "Content-Type": "application/json",
        }
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/dispatch/model/polling/check",
            type_=LLMPollingResult,
            data=request_body,
            headers=request_headers,
        )
    except PluginInvokeError as error:
        # Raises the dedicated error when applicable; otherwise falls through
        # so the original invoke error is re-raised as-is.
        self._raise_typed_polling_unsupported_error(error)
        raise
@staticmethod
def _raise_typed_polling_unsupported_error(error: PluginInvokeError) -> None:
    """Re-raise ``error`` as ``PluginLLMPollingUnsupportedError`` when it signals missing polling support.

    Returns ``None`` without raising when ``error`` does not match; the caller
    is then expected to re-raise the original error itself.
    """
    # Case 1: the reported error type already carries the dedicated class name.
    if error.get_error_type() == PluginLLMPollingUnsupportedError.__name__:
        raise PluginLLMPollingUnsupportedError(description=error.description) from error

    # Case 2: a generic error type that is only recognizable by its message.
    # Matching on message text is fragile; it should be replaced once the
    # error type alone is enough to identify this condition.
    if error.get_error_type() not in _POLLING_UNSUPPORTED_INVOKE_ERROR_TYPES:
        return
    if _POLLING_UNSUPPORTED_ERROR_MESSAGE not in error.get_error_message().lower():
        return
    raise PluginLLMPollingUnsupportedError(description=error.description) from error
def get_llm_num_tokens(
self,
tenant_id: str,

View File

@ -6,6 +6,7 @@ from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from pydantic.json_schema import JsonValue
from redis import RedisError
from configs import dify_config
@ -17,6 +18,7 @@ from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
@ -430,6 +432,54 @@ class PluginModelRuntime(ModelRuntime):
tools=list(tools) if tools else None,
)
def start_llm_polling(
    self,
    *,
    provider: str,
    model: str,
    credentials: dict[str, Any],
    model_parameters: dict[str, Any],
    prompt_messages: Sequence[PromptMessage],
    tools: Sequence[PromptMessageTool] | None,
    stop: Sequence[str] | None,
    json_schema: dict[str, Any] | None,
) -> LLMPollingResult:
    """Begin a polling-based LLM invocation through the plugin client.

    ``provider`` is split into a plugin id and a provider name via
    ``_split_provider``; the request is issued with this runtime's tenant and
    user. Sequences are copied into lists, and empty or missing ``tools`` /
    ``stop`` are forwarded as ``None``.
    """
    plugin_id, provider_name = self._split_provider(provider)
    message_list = list(prompt_messages)
    tool_list = list(tools) if tools else None
    stop_list = list(stop) if stop else None
    return self.client.start_llm_polling(
        tenant_id=self.tenant_id,
        user_id=self.user_id,
        plugin_id=plugin_id,
        provider=provider_name,
        model=model,
        credentials=credentials,
        prompt_messages=message_list,
        model_parameters=model_parameters,
        tools=tool_list,
        stop=stop_list,
        json_schema=json_schema,
    )
def check_llm_polling(
    self,
    *,
    provider: str,
    model: str,
    credentials: dict[str, Any],
    plugin_state: dict[str, JsonValue],
) -> LLMPollingResult:
    """Fetch the current polling state of an LLM invocation through the plugin client.

    ``provider`` is split into a plugin id and a provider name via
    ``_split_provider``; the request is issued with this runtime's tenant and
    user, and ``plugin_state`` is passed through unchanged.
    """
    plugin_id, provider_name = self._split_provider(provider)
    request_kwargs = {
        "tenant_id": self.tenant_id,
        "user_id": self.user_id,
        "plugin_id": plugin_id,
        "provider": provider_name,
        "model": model,
        "credentials": credentials,
        "plugin_state": plugin_state,
    }
    return self.client.check_llm_polling(**request_kwargs)
@override
def invoke_text_embedding(
self,

View File

@ -26,6 +26,7 @@ from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
DifyPreparedLLM,
DifyPreparedPollingLLM,
DifyPromptMessageSerializer,
DifyRetrieverAttachmentLoader,
DifyToolFileManager,
@ -531,7 +532,11 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,
"model_factory": self._llm_model_factory,
"model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance,
"model_instance": (
self._wrap_model_instance_for_node(node_data=validated_node_data, model_instance=model_instance)
if wrap_model_instance
else model_instance
),
"memory": self._build_memory_for_llm_node(
node_data=validated_node_data,
model_instance=model_instance,
@ -555,6 +560,23 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY)
return node_init_kwargs
@staticmethod
def _wrap_model_instance_for_node(
    *,
    node_data: LLMCompatibleNodeData,
    model_instance: ModelInstance,
) -> DifyPreparedLLM:
    """Choose the prepared-LLM wrapper for a node.

    Only graphon's LLM node consumes the polling protocol, so classifier and
    extractor nodes stay on the plain wrapper even when the same model
    advertises polling support.
    """
    is_llm_node = node_data.type == BuiltinNodeTypes.LLM
    if not is_llm_node:
        return DifyPreparedLLM(model_instance)
    if not DifyNodeFactory._supports_plugin_llm_polling(model_instance):
        return DifyPreparedLLM(model_instance)
    return DifyPreparedPollingLLM(model_instance)
@staticmethod
def _supports_plugin_llm_polling(model_instance: ModelInstance) -> bool:
    """Return the ``support_polling`` flag of the model instance's schema."""
    schema = model_instance.get_model_schema()
    return schema.support_polling
def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader:
return DifyRetrieverAttachmentLoader(
file_reference_factory=self._file_reference_factory,

View File

@ -4,6 +4,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, cast, overload, override
from pydantic import JsonValue
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -38,6 +39,7 @@ from factories import file_factory
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities import LLMMode
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
@ -54,6 +56,7 @@ from graphon.nodes.human_input.entities import (
HumanInputNodeData,
)
from graphon.nodes.llm.runtime_protocols import (
LLMPollingCapableProtocol,
LLMProtocol,
PromptMessageSerializerProtocol,
RetrieverAttachmentLoaderProtocol,
@ -278,6 +281,58 @@ class DifyPreparedLLM(LLMProtocol):
return isinstance(error, OutputParserError)
class DifyPreparedPollingLLM(DifyPreparedLLM, LLMPollingCapableProtocol):
    """Prepared LLM wrapper that also implements Graphon's polling protocol.

    Both polling calls are forwarded to the ``PluginModelRuntime`` found on
    the wrapped model instance, using the wrapper's provider, model name and
    the model instance's credentials.
    """

    def __init__(self, model_instance: ModelInstance) -> None:
        """Wrap ``model_instance`` and capture its plugin model runtime.

        Raises:
            TypeError: if the model type instance is not a ``LargeLanguageModel``
                or its runtime is not a ``PluginModelRuntime``.
        """
        # NOTE(review): imported at call time rather than module level,
        # presumably to avoid an import cycle - confirm.
        from core.plugin.impl.model_runtime import PluginModelRuntime

        super().__init__(model_instance)

        llm = model_instance.model_type_instance
        if not isinstance(llm, LargeLanguageModel):
            raise TypeError("Polling wrapper requires a large-language-model instance.")

        runtime = llm.model_runtime
        if not isinstance(runtime, PluginModelRuntime):
            raise TypeError("Polling wrapper requires a plugin-backed model runtime.")

        self._plugin_model_runtime = runtime

    @override
    def start_llm_polling(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Mapping[str, Any],
        tools: Sequence[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        json_schema: Mapping[str, Any] | None,
    ) -> LLMPollingResult:
        """Start a polling invocation; mapping arguments are copied into plain dicts."""
        parameters = dict(model_parameters)
        schema = None if json_schema is None else dict(json_schema)
        return self._plugin_model_runtime.start_llm_polling(
            provider=self.provider,
            model=self.model_name,
            credentials=self._model_instance.credentials,
            prompt_messages=prompt_messages,
            model_parameters=parameters,
            tools=tools,
            stop=stop,
            json_schema=schema,
        )

    @override
    def check_llm_polling(
        self,
        *,
        plugin_state: Mapping[str, JsonValue],
    ) -> LLMPollingResult:
        """Check a polling invocation; ``plugin_state`` is copied into a plain dict."""
        state = dict(plugin_state)
        return self._plugin_model_runtime.check_llm_polling(
            provider=self.provider,
            model=self.model_name,
            credentials=self._model_instance.credentials,
            plugin_state=state,
        )
class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
@override
def serialize(

View File

@ -44,7 +44,7 @@ dependencies = [
"resend>=2.27.0,<3.0.0",
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]==0.7.0",
"graphon==0.5.1",
"graphon==0.5.2",
"httpx-sse==0.4.3",
"json-repair==0.59.4",
]

View File

@ -7,6 +7,7 @@ from pytest_mock import MockerFixture
from core.plugin.endpoint.exc import EndpointSetupFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
from core.plugin.impl.base import PLUGIN_DAEMON_MAX_PATH_LENGTH, BasePluginClient
from core.plugin.impl.exc import PluginLLMPollingUnsupportedError
from core.trigger.errors import (
EventIgnoreError,
TriggerInvokeError,
@ -167,3 +168,10 @@ class TestBasePluginClientImpl:
with pytest.raises(expected):
client._handle_plugin_daemon_error("PluginInvokeError", message)
def test_handle_plugin_daemon_error_maps_unsupported_polling_to_typed_exception(self):
    """A PluginInvokeError payload naming the polling-unsupported type raises that typed exception."""
    error_payload = {"error_type": PluginLLMPollingUnsupportedError.__name__, "message": "m"}
    serialized = json.dumps(error_payload)
    plugin_client = BasePluginClient()

    with pytest.raises(PluginLLMPollingUnsupportedError):
        plugin_client._handle_plugin_daemon_error("PluginInvokeError", serialized)

View File

@ -1,13 +1,17 @@
from __future__ import annotations
import io
import json
from types import SimpleNamespace
import pytest
from pytest_mock import MockerFixture
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
from core.plugin.impl.exc import PluginInvokeError, PluginLLMPollingUnsupportedError
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.llm_entities import LLMPollingResult, LLMPollingStatus, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
class TestPluginModelClient:
@ -183,6 +187,113 @@ class TestPluginModelClient:
)
)
def test_start_llm_polling(self, mocker: MockerFixture):
    """start_llm_polling posts the LLM payload to the polling/start route and returns the daemon result."""
    expected_result = LLMPollingResult(
        status=LLMPollingStatus.RUNNING,
        plugin_state={"task_id": "poll-1"},
        next_check_after_seconds=3,
    )
    client = PluginModelClient()
    daemon_request = mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response",
        return_value=expected_result,
    )

    actual = client.start_llm_polling(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="gpt-test",
        credentials={"api_key": "key"},
        prompt_messages=[],
        model_parameters={"temperature": 0.1},
        tools=[],
        stop=["STOP"],
        json_schema={"type": "object"},
    )

    assert actual == expected_result

    sent = daemon_request.call_args.kwargs
    expected_payload = {
        "provider": "provider-a",
        "model_type": "llm",
        "model": "gpt-test",
        "credentials": {"api_key": "key"},
        "prompt_messages": [],
        "model_parameters": {"temperature": 0.1},
        "tools": [],
        "stop": ["STOP"],
        "stream": False,
        "json_schema": {"type": "object"},
    }
    assert sent["path"] == "plugin/tenant-1/dispatch/model/polling/start"
    assert sent["data"]["data"] == expected_payload
def test_check_llm_polling(self, mocker: MockerFixture):
    """check_llm_polling posts the plugin state to the polling/check route and returns the daemon result."""
    finished_llm_result = LLMResult(
        model="gpt-test",
        prompt_messages=[],
        message=AssistantPromptMessage(content="done"),
        usage=LLMUsage.empty_usage(),
    )
    expected_result = LLMPollingResult(status=LLMPollingStatus.SUCCEEDED, result=finished_llm_result)
    client = PluginModelClient()
    daemon_request = mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response",
        return_value=expected_result,
    )

    actual = client.check_llm_polling(
        tenant_id="tenant-1",
        user_id="user-1",
        plugin_id="org/plugin:1",
        provider="provider-a",
        model="gpt-test",
        credentials={"api_key": "key"},
        plugin_state={"task_id": "poll-1"},
    )

    assert actual == expected_result

    sent = daemon_request.call_args.kwargs
    expected_payload = {
        "provider": "provider-a",
        "model_type": "llm",
        "model": "gpt-test",
        "credentials": {"api_key": "key"},
        "plugin_state": {"task_id": "poll-1"},
    }
    assert sent["path"] == "plugin/tenant-1/dispatch/model/polling/check"
    assert sent["data"]["data"] == expected_payload
@pytest.mark.parametrize(
    "error_type",
    [
        # The plugin side names the dedicated error type directly.
        PluginLLMPollingUnsupportedError.__name__,
        # Fallback path: a generic NotImplementedError that is recognized by
        # its "does not support polling" message.
        NotImplementedError.__name__,
    ],
)
def test_start_llm_polling_maps_unsupported_polling_invoke_error(self, mocker: MockerFixture, error_type: str):
    """Invoke errors that signal missing polling support surface as PluginLLMPollingUnsupportedError.

    Both recognition paths of the client are covered: the explicit error type
    and the message-based fallback for ``NotImplementedError``.
    """
    client = PluginModelClient()
    invoke_error = PluginInvokeError(
        json.dumps(
            {
                "error_type": error_type,
                "message": "Model `gpt-test` does not support polling.",
            }
        )
    )
    mocker.patch.object(
        client,
        "_request_with_plugin_daemon_response",
        side_effect=invoke_error,
    )

    with pytest.raises(PluginLLMPollingUnsupportedError):
        client.start_llm_polling(
            tenant_id="tenant-1",
            user_id="user-1",
            plugin_id="org/plugin:1",
            provider="provider-a",
            model="gpt-test",
            credentials={"api_key": "key"},
            prompt_messages=[],
        )
def test_get_llm_num_tokens(self, mocker: MockerFixture):
client = PluginModelClient()
mocker.patch.object(

View File

@ -14,7 +14,13 @@ from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, Pl
from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
from core.plugin.plugin_service import PluginService
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMPollingStatus,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
)
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
@ -282,6 +288,74 @@ class TestPluginModelRuntime:
stream=True,
)
def test_start_llm_polling_resolves_plugin_fields(self) -> None:
    """The runtime splits the provider id and forwards tenant, user and list-converted arguments."""
    expected_result = LLMPollingResult(
        status=LLMPollingStatus.RUNNING,
        plugin_state={"task_id": "poll-1"},
        next_check_after_seconds=2,
    )
    plugin_client = Mock(spec=PluginModelClient)
    plugin_client.start_llm_polling.return_value = expected_result
    runtime = PluginModelRuntime(
        tenant_id="tenant",
        user_id="user",
        client=plugin_client,
        plugin_service=PluginService,
    )

    actual = runtime.start_llm_polling(
        provider="langgenius/openai/openai",
        model="gpt-4o-mini",
        credentials={"api_key": "secret"},
        model_parameters={"temperature": 0.2},
        prompt_messages=[],
        tools=None,
        stop=("END",),
        json_schema={"type": "object"},
    )

    assert actual == expected_result
    expected_call = {
        "tenant_id": "tenant",
        "user_id": "user",
        "plugin_id": "langgenius/openai",
        "provider": "openai",
        "model": "gpt-4o-mini",
        "credentials": {"api_key": "secret"},
        "prompt_messages": [],
        "model_parameters": {"temperature": 0.2},
        "tools": None,
        "stop": ["END"],
        "json_schema": {"type": "object"},
    }
    plugin_client.start_llm_polling.assert_called_once_with(**expected_call)
def test_check_llm_polling_resolves_plugin_fields(self) -> None:
    """The runtime splits the provider id and forwards tenant, user and plugin state to the client."""
    finished_llm_result = model_runtime_module.LLMResult(
        model="gpt-4o-mini",
        prompt_messages=[],
        message=AssistantPromptMessage(content="done"),
        usage=LLMUsage.empty_usage(),
    )
    expected_result = LLMPollingResult(status=LLMPollingStatus.SUCCEEDED, result=finished_llm_result)
    plugin_client = Mock(spec=PluginModelClient)
    plugin_client.check_llm_polling.return_value = expected_result
    runtime = PluginModelRuntime(
        tenant_id="tenant",
        user_id="user",
        client=plugin_client,
        plugin_service=PluginService,
    )

    actual = runtime.check_llm_polling(
        provider="langgenius/openai/openai",
        model="gpt-4o-mini",
        credentials={"api_key": "secret"},
        plugin_state={"task_id": "poll-1"},
    )

    assert actual == expected_result
    expected_call = {
        "tenant_id": "tenant",
        "user_id": "user",
        "plugin_id": "langgenius/openai",
        "provider": "openai",
        "model": "gpt-4o-mini",
        "credentials": {"api_key": "secret"},
        "plugin_state": {"task_id": "poll-1"},
    }
    plugin_client.check_llm_polling.assert_called_once_with(**expected_call)
def test_invoke_llm_rejects_per_call_user_override(self) -> None:
client = Mock(spec=PluginModelClient)
client.invoke_llm.return_value = sentinel.result

View File

@ -703,9 +703,12 @@ class TestSchemaResolverClass:
# For schemas without refs, hybrid should be competitive or better
if not expected: # No refs case
# Hybrid might be slightly slower due to JSON serialization overhead,
# but should not be dramatically worse
assert avg_hybrid < avg_recursive * 5 # At most 5x slower
relative_slowdown_limit = 5.0
absolute_noise_budget_seconds = 2e-4
# JSON serialization has a fixed overhead that dominates tiny schemas,
# so allow a small absolute noise budget on top of the relative limit.
assert avg_hybrid < (avg_recursive * relative_slowdown_limit) + absolute_noise_budget_seconds
def test_string_matching_edge_cases(self):
"""Test edge cases for string-based detection"""

View File

@ -1,18 +1,26 @@
from collections.abc import Mapping
from types import SimpleNamespace
from unittest.mock import MagicMock, patch, sentinel
from unittest.mock import MagicMock, Mock, patch, sentinel
import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.plugin.plugin_service import PluginService
from core.workflow import node_factory
from core.workflow import template_rendering as workflow_template_rendering
from core.workflow.node_runtime import DifyPreparedLLM
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.code.entities import CodeLanguage
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.runtime_protocols import LLMPollingCapableProtocol
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.variables.segments import ArrayObjectSegment, StringSegment
@ -35,6 +43,41 @@ def _node_constructor(*, return_value):
return constructor
def _build_llm_model_schema(*, features: list[ModelFeature] | None = None) -> AIModelEntity:
    """Build a predefined LLM model schema named "model" with the given feature list."""
    schema = AIModelEntity(
        model="model",
        label=I18nObject(en_US="Model"),
        model_type=ModelType.LLM,
        fetch_from=FetchFrom.PREDEFINED_MODEL,
        model_properties={},
        features=features,
    )
    return schema
class _ModelTypeInstanceStub(LargeLanguageModel):
    """LargeLanguageModel stand-in that carries nothing but a model runtime."""

    def __init__(self, *, model_runtime: object) -> None:
        # The parent initializer is not invoked; only the runtime is stored.
        self.model_runtime = model_runtime
class _ModelInstanceStub:
    """Stand-in for a model instance with fixed identity fields and a supplied schema."""

    def __init__(
        self,
        *,
        model_runtime: object,
        model_schema: AIModelEntity,
    ) -> None:
        # Schema and runtime come from the test; everything else is fixed.
        self._model_schema = model_schema
        self.model_type_instance = _ModelTypeInstanceStub(model_runtime=model_runtime)

        self.provider = "langgenius/openai/openai"
        self.model_name = "model"
        self.credentials = {"api_key": "secret"}
        self.parameters = {}
        self.stop = ()

    def get_model_schema(self) -> AIModelEntity:
        """Return the schema handed to the constructor."""
        return self._model_schema
class TestResolveWorkflowNodeClass:
def test_matching_version_uses_registry_mapping(self, monkeypatch) -> None:
document_extractor_class = sentinel.document_extractor_class
@ -667,7 +710,7 @@ class TestDifyNodeFactoryCreateNode:
memory = sentinel.memory
factory._build_model_instance_for_llm_node = MagicMock(return_value=sentinel.model_instance)
factory._build_memory_for_llm_node = MagicMock(return_value=memory)
with patch.object(node_factory, "DifyPreparedLLM", return_value=wrapped_model_instance) as prepared_llm:
with patch.object(factory, "_wrap_model_instance_for_node", return_value=wrapped_model_instance) as wrap_model:
kwargs = factory._build_llm_compatible_node_init_kwargs(
node_class=sentinel.node_class,
node_data=node_data,
@ -686,9 +729,70 @@ class TestDifyNodeFactoryCreateNode:
node_data=node_data,
model_instance=sentinel.model_instance,
)
prepared_llm.assert_called_once_with(sentinel.model_instance)
wrap_model.assert_called_once_with(
node_data=node_data,
model_instance=sentinel.model_instance,
)
assert kwargs["model_instance"] is wrapped_model_instance
def test_build_llm_compatible_node_init_kwargs_uses_polling_wrapper_for_polling_llm_node(self, factory):
    """An LLM node whose model schema lists the POLLING feature gets a polling-capable wrapper."""
    llm_node_data = LLMNodeData.model_validate(
        {
            "type": BuiltinNodeTypes.LLM,
            "title": "LLM",
            "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}},
            "prompt_template": [{"role": "system", "text": "x"}],
            "context": {"enabled": False, "variable_selector": []},
            "vision": {"enabled": False},
        }
    )
    runtime = PluginModelRuntime(
        tenant_id="tenant-id",
        user_id="user-id",
        client=Mock(spec=PluginModelClient),
        plugin_service=PluginService,
    )
    polling_schema = _build_llm_model_schema(features=[ModelFeature.POLLING])
    polling_model = _ModelInstanceStub(model_runtime=runtime, model_schema=polling_schema)
    factory._build_model_instance_for_llm_node = MagicMock(return_value=polling_model)
    factory._build_memory_for_llm_node = MagicMock(return_value=sentinel.memory)

    init_kwargs = factory._build_llm_compatible_node_init_kwargs(
        node_class=sentinel.node_class,
        node_data=llm_node_data,
        wrap_model_instance=True,
        include_http_client=False,
        include_llm_file_saver=False,
        include_prompt_message_serializer=False,
        include_retriever_attachment_loader=False,
        include_jinja2_template_renderer=False,
    )

    assert isinstance(init_kwargs["model_instance"], LLMPollingCapableProtocol)
@pytest.mark.parametrize("node_type", [BuiltinNodeTypes.QUESTION_CLASSIFIER, BuiltinNodeTypes.PARAMETER_EXTRACTOR])
def test_wrap_model_instance_keeps_non_llm_graph_nodes_on_plain_wrapper(self, node_type):
    """Classifier and extractor nodes get the plain wrapper even when the schema lists POLLING."""
    runtime = PluginModelRuntime(
        tenant_id="tenant-id",
        user_id="user-id",
        client=Mock(spec=PluginModelClient),
        plugin_service=PluginService,
    )
    polling_schema = _build_llm_model_schema(features=[ModelFeature.POLLING])
    polling_model = _ModelInstanceStub(model_runtime=runtime, model_schema=polling_schema)
    non_llm_node_data = SimpleNamespace(type=node_type)

    wrapper = node_factory.DifyNodeFactory._wrap_model_instance_for_node(
        node_data=non_llm_node_data,
        model_instance=polling_model,
    )

    # Exact type check: the polling wrapper subclasses DifyPreparedLLM.
    assert type(wrapper) is DifyPreparedLLM
    assert not isinstance(wrapper, LLMPollingCapableProtocol)
def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory):
created_node = object()
constructor = _node_constructor(return_value=created_node)

View File

@ -7,6 +7,10 @@ import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom
from core.app.file_access import FileAccessScope, bind_file_access_scope, grant_retriever_segment_access
from core.llm_generator.output_parser.errors import OutputParserError
from core.plugin.impl.exc import PluginLLMPollingUnsupportedError
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.plugin.plugin_service import PluginService
from core.workflow import node_runtime
from core.workflow.file_reference import parse_file_reference
from core.workflow.human_input_adapter import (
@ -21,6 +25,7 @@ from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
DifyPreparedLLM,
DifyPreparedPollingLLM,
DifyPromptMessageSerializer,
DifyRetrieverAttachmentLoader,
DifyToolFileManager,
@ -31,23 +36,61 @@ from core.workflow.node_runtime import (
)
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.llm_entities import LLMPollingResult, LLMPollingStatus
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import FileInputConfig, FileListInputConfig, HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import LLMPollingCapableProtocol
from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType
from graphon.variables.segments import ArrayFileSegment, FileSegment
from tests.workflow_test_utils import build_test_run_context
def _build_model_schema() -> AIModelEntity:
def _build_model_schema(*, features: list[ModelFeature] | None = None) -> AIModelEntity:
return AIModelEntity(
model="gpt-4o-mini",
label=I18nObject(en_US="GPT-4o mini"),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
features=features,
)
class _ModelTypeInstanceStub(LargeLanguageModel):
    """LargeLanguageModel stand-in exposing a runtime and a mocked ``get_model_schema``."""

    def __init__(
        self,
        *,
        model_schema: AIModelEntity | None,
        model_runtime: object | None = None,
    ) -> None:
        # The parent initializer is not invoked. ``get_model_schema`` is a
        # Mock so tests can inspect how it was called.
        self.get_model_schema = Mock(return_value=model_schema)
        self.model_runtime = model_runtime
class _ModelInstanceStub:
    """Stand-in for a model instance with fixed identity fields and mocked LLM calls."""

    def __init__(
        self,
        *,
        model_schema: AIModelEntity | None,
        model_runtime: object | None = None,
        invoke_llm_result: object = sentinel.result,
        get_llm_num_tokens_result: int = 32,
    ) -> None:
        # Collaborators configured by the test.
        self.model_type_instance = _ModelTypeInstanceStub(
            model_schema=model_schema,
            model_runtime=model_runtime,
        )
        self.invoke_llm = Mock(return_value=invoke_llm_result)
        self.get_llm_num_tokens = Mock(return_value=get_llm_num_tokens_result)

        # Fixed identity and configuration values.
        self.provider = "langgenius/openai/openai"
        self.model_name = "gpt-4o-mini"
        self.credentials = {"api_key": "secret"}
        self.parameters = {"temperature": 0.2}
        self.stop = ("stop",)
def _build_run_context(*, invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER) -> dict[str, object]:
return build_test_run_context(
tenant_id="tenant-id",
@ -126,17 +169,8 @@ def test_dify_file_reference_factory_passes_tenant_id(monkeypatch: pytest.Monkey
def test_dify_prepared_llm_wraps_model_instance_calls() -> None:
model_schema = _build_model_schema()
model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value=model_schema))
model_instance = SimpleNamespace(
provider="langgenius/openai/openai",
model_name="gpt-4o-mini",
parameters={"temperature": 0.2},
stop=("stop",),
credentials={"api_key": "secret"},
model_type_instance=model_type_instance,
get_llm_num_tokens=Mock(return_value=32),
invoke_llm=Mock(return_value=sentinel.result),
)
model_instance = _ModelInstanceStub(model_schema=model_schema)
model_type_instance = model_instance.model_type_instance
prepared = DifyPreparedLLM(model_instance)
assert prepared.provider == "langgenius/openai/openai"
@ -167,11 +201,8 @@ def test_dify_prepared_llm_wraps_model_instance_calls() -> None:
def test_dify_prepared_llm_requires_model_schema() -> None:
model_instance = SimpleNamespace(
model_name="gpt-4o-mini",
credentials={},
model_type_instance=SimpleNamespace(get_model_schema=Mock(return_value=None)),
)
model_instance = _ModelInstanceStub(model_schema=None)
model_instance.credentials = {}
prepared = DifyPreparedLLM(model_instance)
with pytest.raises(ValueError, match="Model schema not found"):
@ -179,12 +210,7 @@ def test_dify_prepared_llm_requires_model_schema() -> None:
def test_dify_prepared_llm_delegates_structured_output_helper(monkeypatch: pytest.MonkeyPatch) -> None:
model_instance = SimpleNamespace(
provider="langgenius/openai/openai",
model_name="gpt-4o-mini",
credentials={"api_key": "secret"},
model_type_instance=SimpleNamespace(get_model_schema=Mock(return_value=_build_model_schema())),
)
model_instance = _ModelInstanceStub(model_schema=_build_model_schema())
prepared = DifyPreparedLLM(model_instance)
invoke_structured = MagicMock(return_value=sentinel.structured)
monkeypatch.setattr(node_runtime, "invoke_llm_with_structured_output", invoke_structured)
@ -217,6 +243,94 @@ def test_dify_prepared_llm_identifies_structured_output_errors() -> None:
assert prepared.is_structured_output_parse_error(ValueError("other")) is False
def test_dify_prepared_polling_llm_delegates_to_plugin_runtime() -> None:
    """Check that both polling entry points are forwarded to the plugin runtime.

    ``start_llm_polling`` and ``check_llm_polling`` on the runtime are replaced
    with mocks returning one shared result; the prepared wrapper must hand that
    result back unchanged and call the runtime with provider, model and
    credentials added to the caller-supplied arguments.
    """
    expected = LLMPollingResult(
        status=LLMPollingStatus.RUNNING,
        plugin_state={"task_id": "poll-1"},
        next_check_after_seconds=2,
    )

    runtime = PluginModelRuntime(
        tenant_id="tenant-id",
        user_id="user-id",
        client=Mock(spec=PluginModelClient),
        plugin_service=PluginService,
    )
    # Stub out the runtime methods so no plugin daemon call is attempted.
    runtime.start_llm_polling = Mock(return_value=expected)  # type: ignore[method-assign]
    runtime.check_llm_polling = Mock(return_value=expected)  # type: ignore[method-assign]

    schema = _build_model_schema(features=[ModelFeature.POLLING])
    stub = _ModelInstanceStub(model_schema=schema, model_runtime=runtime)
    prepared = DifyPreparedPollingLLM(stub)

    assert isinstance(prepared, LLMPollingCapableProtocol)

    started = prepared.start_llm_polling(
        prompt_messages=[],
        model_parameters={"temperature": 0.1},
        tools=[],
        stop=("END",),
        json_schema={"type": "object"},
    )
    assert started == expected

    checked = prepared.check_llm_polling(plugin_state={"task_id": "poll-1"})
    assert checked == expected

    # NOTE(review): provider / model / credentials presumably come from the
    # defaults of `_ModelInstanceStub` -- confirm against the stub definition.
    runtime.start_llm_polling.assert_called_once_with(
        provider="langgenius/openai/openai",
        model="gpt-4o-mini",
        credentials={"api_key": "secret"},
        prompt_messages=[],
        model_parameters={"temperature": 0.1},
        tools=[],
        stop=("END",),
        json_schema={"type": "object"},
    )
    runtime.check_llm_polling.assert_called_once_with(
        provider="langgenius/openai/openai",
        model="gpt-4o-mini",
        credentials={"api_key": "secret"},
        plugin_state={"task_id": "poll-1"},
    )
def test_dify_prepared_polling_llm_raise_exception_when_polling_is_unsupported() -> None:
    """Check that an unsupported-polling error from the runtime reaches the caller.

    The runtime's ``start_llm_polling`` is mocked to raise
    ``PluginLLMPollingUnsupportedError``; calling the prepared wrapper must
    raise that same exception type.
    """
    # NOTE(review): a synchronous result is wired into the stub, presumably so
    # a silent fallback to `invoke_llm` would not raise and the test would
    # fail -- confirm against `_ModelInstanceStub`.
    sync_result = node_runtime.LLMResult(
        model="gpt-4o-mini",
        prompt_messages=[],
        message=AssistantPromptMessage(content="sync-result"),
        usage=node_runtime.LLMUsage.empty_usage(),
    )

    runtime = PluginModelRuntime(
        tenant_id="tenant-id",
        user_id="user-id",
        client=Mock(),
        plugin_service=Mock(),
    )
    unsupported = PluginLLMPollingUnsupportedError("Polling unsupported")
    runtime.start_llm_polling = Mock(side_effect=unsupported)  # type: ignore[method-assign]

    schema = _build_model_schema(features=[ModelFeature.POLLING])
    stub = _ModelInstanceStub(
        model_schema=schema,
        model_runtime=runtime,
        invoke_llm_result=sync_result,
    )
    prepared = DifyPreparedPollingLLM(stub)

    with pytest.raises(PluginLLMPollingUnsupportedError):
        prepared.start_llm_polling(
            prompt_messages=[],
            model_parameters={"temperature": 0.1},
            tools=None,
            stop=None,
            json_schema=None,
        )
def test_dify_prompt_message_serializer_delegates(monkeypatch: pytest.MonkeyPatch) -> None:
serialize = MagicMock(return_value={"prompt": "value"})
monkeypatch.setattr(node_runtime.PromptMessageUtil, "prompt_messages_to_prompt_for_saving", serialize)

10
api/uv.lock generated
View File

@ -1293,7 +1293,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "fastapi", marker = "extra == 'server'", specifier = "==0.136.0" },
{ name = "graphon", marker = "extra == 'server'", specifier = "==0.5.1" },
{ name = "graphon", marker = "extra == 'server'", specifier = "==0.5.2" },
{ name = "grpclib", extras = ["protobuf"], marker = "extra == 'grpc'", specifier = ">=0.4.9,<0.5.0" },
{ name = "httpx", specifier = "==0.28.1" },
{ name = "jsonschema", marker = "extra == 'server'", specifier = ">=4.23.0,<5.0.0" },
@ -1636,7 +1636,7 @@ requires-dist = [
{ name = "gmpy2", specifier = ">=2.3.0,<3.0.0" },
{ name = "google-api-python-client", specifier = ">=2.196.0,<3.0.0" },
{ name = "google-cloud-aiplatform", specifier = ">=1.151.0,<2.0.0" },
{ name = "graphon", specifier = "==0.5.1" },
{ name = "graphon", specifier = "==0.5.2" },
{ name = "gunicorn", specifier = ">=26.0.0,<27.0.0" },
{ name = "httpx", extras = ["socks"], specifier = "==0.28.1" },
{ name = "httpx-sse", specifier = "==0.4.3" },
@ -2987,7 +2987,7 @@ httpx = [
[[package]]
name = "graphon"
version = "0.5.1"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "charset-normalizer" },
@ -3008,9 +3008,9 @@ dependencies = [
{ name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] },
{ name = "webvtt-py" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a2/fa/432fa802bcb13f7f51dc323ddef92594b15333eafef181d937ffa554116e/graphon-0.5.1.tar.gz", hash = "sha256:ca38cc62ef3fbc2f3072b68235bcb41e32a6369a1753b46418c1d761c57125fe", size = 269741, upload-time = "2026-06-11T03:01:38.197Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c2/16/f183da187414c335be67f52f6a1b7c2a33bf0b1d5090eda7e6c92d42d94a/graphon-0.5.2.tar.gz", hash = "sha256:d66a9edcd883766bd50e94f84a691c92ce536ea60e721552089e83ac8e94bf68", size = 269773, upload-time = "2026-06-16T04:06:22.074Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/c5/61e8634b89c320af9453083213e8be436071634dbc69cb14b5fe646763e4/graphon-0.5.1-py3-none-any.whl", hash = "sha256:70b49c244a46fb6e338905210cc895bd67584d9ab1412f6ba3cd4ed284010091", size = 381866, upload-time = "2026-06-11T03:01:36.693Z" },
{ url = "https://files.pythonhosted.org/packages/2f/e6/36a3981cd44e7a40a7cd7d374e26f01e02dd49410c5fbbd7df248750d5fb/graphon-0.5.2-py3-none-any.whl", hash = "sha256:11f89399e67ed1ddd2ce1c336accd9c4ad5b8fe2741f9167e6085af0b325cd14", size = 381908, upload-time = "2026-06-16T04:06:20.453Z" },
]
[[package]]

View File

@ -20,7 +20,7 @@ dify-agent-stub-server = "dify_agent.agent_stub.server.cli:main"
grpc = ["grpclib[protobuf]>=0.4.9,<0.5.0", "protobuf>=6.33.5,<7.0.0"]
server = [
"fastapi==0.136.0",
"graphon==0.5.1",
"graphon==0.5.2",
"jsonschema>=4.23.0,<5.0.0",
"jwcrypto>=1.5.6,<2",
"pydantic-ai-slim[anthropic,google,openai]>=1.85.1,<2.0.0",

View File

@ -16,7 +16,7 @@ CLIENT_SHARED_DTO_DEPENDENCIES = {
SERVER_RUNTIME_DEPENDENCIES = {
"fastapi==0.136.0",
"graphon==0.5.1",
"graphon==0.5.2",
"jsonschema>=4.23.0,<5.0.0",
"jwcrypto>=1.5.6,<2",
"pydantic-ai-slim[anthropic,google,openai]>=1.85.1,<2.0.0",

8
dify-agent/uv.lock generated
View File

@ -628,7 +628,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "fastapi", marker = "extra == 'server'", specifier = "==0.136.0" },
{ name = "graphon", marker = "extra == 'server'", specifier = "==0.5.1" },
{ name = "graphon", marker = "extra == 'server'", specifier = "==0.5.2" },
{ name = "grpclib", extras = ["protobuf"], marker = "extra == 'grpc'", specifier = ">=0.4.9,<0.5.0" },
{ name = "httpx", specifier = "==0.28.1" },
{ name = "jsonschema", marker = "extra == 'server'", specifier = ">=4.23.0,<5.0.0" },
@ -808,7 +808,7 @@ wheels = [
[[package]]
name = "graphon"
version = "0.5.1"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "charset-normalizer" },
@ -829,9 +829,9 @@ dependencies = [
{ name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] },
{ name = "webvtt-py" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a2/fa/432fa802bcb13f7f51dc323ddef92594b15333eafef181d937ffa554116e/graphon-0.5.1.tar.gz", hash = "sha256:ca38cc62ef3fbc2f3072b68235bcb41e32a6369a1753b46418c1d761c57125fe", size = 269741, upload-time = "2026-06-11T03:01:38.197Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c2/16/f183da187414c335be67f52f6a1b7c2a33bf0b1d5090eda7e6c92d42d94a/graphon-0.5.2.tar.gz", hash = "sha256:d66a9edcd883766bd50e94f84a691c92ce536ea60e721552089e83ac8e94bf68", size = 269773, upload-time = "2026-06-16T04:06:22.074Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/c5/61e8634b89c320af9453083213e8be436071634dbc69cb14b5fe646763e4/graphon-0.5.1-py3-none-any.whl", hash = "sha256:70b49c244a46fb6e338905210cc895bd67584d9ab1412f6ba3cd4ed284010091", size = 381866, upload-time = "2026-06-11T03:01:36.693Z" },
{ url = "https://files.pythonhosted.org/packages/2f/e6/36a3981cd44e7a40a7cd7d374e26f01e02dd49410c5fbbd7df248750d5fb/graphon-0.5.2-py3-none-any.whl", hash = "sha256:11f89399e67ed1ddd2ce1c336accd9c4ad5b8fe2741f9167e6085af0b325cd14", size = 381908, upload-time = "2026-06-16T04:06:20.453Z" },
]
[[package]]