diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 4a7918032e..69be25bfb5 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -5,7 +5,7 @@ This layer centralizes model-quota deduction outside node implementations. """ import logging -from typing import TYPE_CHECKING, cast, final, override +from typing import final, override from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available @@ -17,11 +17,6 @@ from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node -if TYPE_CHECKING: - from graphon.nodes.llm.node import LLMNode - from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode - logger = logging.getLogger(__name__) @@ -109,16 +104,14 @@ class LLMQuotaLayer(GraphEngineLayer): @staticmethod def _extract_model_instance(node: Node) -> ModelInstance | None: + match node.node_type: + case BuiltinNodeTypes.LLM | BuiltinNodeTypes.PARAMETER_EXTRACTOR | BuiltinNodeTypes.QUESTION_CLASSIFIER: + pass + case _: + return None + try: - match node.node_type: - case BuiltinNodeTypes.LLM: - model_instance = cast("LLMNode", node).model_instance - case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - model_instance = cast("ParameterExtractorNode", node).model_instance - case BuiltinNodeTypes.QUESTION_CLASSIFIER: - model_instance = cast("QuestionClassifierNode", node).model_instance - case _: - return None + model_instance = getattr(node, "model_instance", None) except AttributeError: logger.warning( "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s", @@ -133,4 +126,12 @@ class LLMQuotaLayer(GraphEngineLayer): if isinstance(raw_model_instance, ModelInstance): return raw_model_instance + private_model_instance = getattr(node, "_model_instance", None) + if isinstance(private_model_instance, ModelInstance): + return private_model_instance + + wrapped_private_model_instance = getattr(private_model_instance, "_model_instance", None) + if isinstance(wrapped_private_model_instance, ModelInstance): + return wrapped_private_model_instance + return None diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index a68220e682..209e37ad6c 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,22 +4,31 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Union +from typing import IO, Any, Literal, cast, overload from pydantic import ValidationError from redis import RedisError from configs import dify_config +from core.llm_generator.output_parser.structured_output import ( + invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper, +) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result from graphon.model_runtime.protocols.runtime import ModelRuntime from models.provider_ids import ModelProviderID @@ -29,6 +38,57 @@ logger = logging.getLogger(__name__) TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" +class _PluginStructuredOutputModelInstance: + """Reuse the shared structured-output helper without depending on `ModelInstance`.""" + + def __init__( + self, + *, + runtime: PluginModelRuntime, + provider: str, + model: str, + credentials: dict[str, Any], + ) -> None: + self._runtime = runtime + self._provider = provider + self._model = model + self._credentials = credentials + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: dict[str, Any] | None = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, + stream: bool = True, + callbacks: object | None = None, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + del callbacks + if stream: + return self._runtime.invoke_llm( + provider=self._provider, + model=self._model, + credentials=self._credentials, + model_parameters=model_parameters or {}, + prompt_messages=prompt_messages, + tools=list(tools) if tools else None, + stop=stop, + stream=True, + ) + + return self._runtime.invoke_llm( + provider=self._provider, + model=self._model, + credentials=self._credentials, + model_parameters=model_parameters or {}, + prompt_messages=prompt_messages, + tools=list(tools) if tools else None, + stop=stop, + stream=False, + ) + + class PluginModelRuntime(ModelRuntime): """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" @@ -195,6 +255,34 @@ class PluginModelRuntime(ModelRuntime): return schema + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -206,9 +294,9 @@ class PluginModelRuntime(ModelRuntime): tools: list[PromptMessageTool] | None, stop: Sequence[str] | None, stream: bool, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + ) -> LLMResult | Generator[LLMResultChunk, None, None]: plugin_id, provider_name = self._split_provider(provider) - return self.client.invoke_llm( + result = self.client.invoke_llm( tenant_id=self.tenant_id, user_id=self.user_id, plugin_id=plugin_id, @@ -221,6 +309,81 @@ class PluginModelRuntime(ModelRuntime): stop=list(stop) if stop else None, stream=stream, ) + if stream: + return result + + return normalize_non_stream_runtime_result( + model=model, + prompt_messages=prompt_messages, + result=result, + ) + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + model_schema = self.get_model_schema( + provider=provider, + model_type=ModelType.LLM, + model=model, + credentials=credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {model}") + + adapter = _PluginStructuredOutputModelInstance( + runtime=self, + provider=provider, + model=model, + credentials=credentials, + ) + return invoke_llm_with_structured_output_helper( + provider=provider, + model_schema=model_schema, + model_instance=cast(Any, adapter), + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + tools=None, + stop=list(stop) if stop else None, + stream=stream, + ) def get_llm_num_tokens( self, diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py index 9d15a3fcea..77ef3826e9 100644 --- a/api/core/workflow/system_variables.py +++ b/api/core/workflow/system_variables.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Protocol, cast +from typing import Any, Protocol from uuid import uuid4 from graphon.enums import BuiltinNodeTypes @@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: normalized = _normalize_system_variable_values(values, **kwargs) return [ - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=system_variable_selector(key), - name=key, - ), + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, ) for key, value in normalized.items() ] @@ -130,13 +127,10 @@ def build_bootstrap_variables( for node_id, value in rag_pipeline_variables_map.items(): variables.append( - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), - name=node_id, - ), + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, ) ) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index a55448e352..59db147576 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader): # This approach reduces loading time by querying external systems concurrently. with ThreadPoolExecutor(max_workers=10) as executor: offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars) - for selector, variable in offloaded_variables: - variable_by_selector[selector] = variable + for selector, offloaded_variable in offloaded_variables: + variable_by_selector[selector] = offloaded_variable return list(variable_by_selector.values()) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 86ebb8800a..9345113aa3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -45,7 +45,7 @@ def init_code_node(code_config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e5f1626072..7cd7f50b77 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -55,7 +55,7 @@ def init_http_node(config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.runtime import VariablePool # Create variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], @@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): ) # Create independent variable pool for this test only - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index cc37a1ff1a..92f3a1926c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="aaa", app_id=app_id, diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 5de0167a6f..f11188323a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 29737afa4f..80489e6809 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -66,7 +66,7 @@ def test_execute_template_transform(): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 1c49bbdb97..f66f65b978 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -41,7 +41,7 @@ def init_tool_node(config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 1c5669ba94..ad82b8fe2a 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 620a153204..248fed5388 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -16,7 +16,7 @@ from models.workflow import Workflow def _make_graph_state(): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, environment_variables=[], diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 30a068f4c5..03f6931ffe 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs): self.id = node_id - self.config = config + self.data = data self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state self.kwargs = kwargs diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 1821f72e0c..3daac5f563 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -14,16 +14,11 @@ from graphon.nodes.llm.entities import LLMNodeData from graphon.variables.segments import StringSegment -def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: +def _assert_typed_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None: _ = node_id - if isinstance(config, BaseNodeData): - assert config.type == node_type - assert config.version == version - return - - assert isinstance(config, dict) - assert config["type"] == node_type - assert config["version"] == version + assert isinstance(data, BaseNodeData) + assert data.type == node_type + assert data.version == version def _node_constructor(*, return_value): @@ -470,7 +465,7 @@ class TestDifyNodeFactoryCreateNode: matched_node_class.assert_called_once() kwargs = matched_node_class.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state latest_node_class.assert_not_called() @@ -490,7 +485,7 @@ class TestDifyNodeFactoryCreateNode: latest_node_class.assert_called_once() kwargs = latest_node_class.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state @@ -528,7 +523,7 @@ class TestDifyNodeFactoryCreateNode: assert result is created_node kwargs = constructor.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) + _assert_typed_node_data(kwargs["data"], node_id="node-id", node_type=node_type) assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state @@ -597,7 +592,7 @@ class TestDifyNodeFactoryCreateNode: prepared_llm.assert_called_once_with(sentinel.model_instance) assert kwargs["model_instance"] is wrapped_model_instance - def test_create_node_passes_alias_preserving_llm_config_to_constructor(self, monkeypatch, factory): + def test_create_node_passes_typed_llm_data_to_constructor(self, monkeypatch, factory): created_node = object() constructor = _node_constructor(return_value=created_node) monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor)) @@ -625,10 +620,10 @@ class TestDifyNodeFactoryCreateNode: factory.create_node(node_config) - config = constructor.call_args.kwargs["config"] - assert isinstance(config, dict) - assert config["structured_output_enabled"] is True - assert "structured_output_switch_on" not in config + data = constructor.call_args.kwargs["data"] + assert isinstance(data, LLMNodeData) + assert data.structured_output_enabled is True + assert data.structured_output_switch_on is True @pytest.mark.parametrize( ("node_type", "constructor_name", "expected_extra_kwargs"), @@ -707,7 +702,7 @@ class TestDifyNodeFactoryCreateNode: constructor_kwargs = constructor.call_args.kwargs assert constructor_kwargs["node_id"] == "node-id" - _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) + _assert_typed_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type) assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9dab38ed8e..0017cd8d3f 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -109,8 +109,8 @@ class TestVariablePool: assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None - def test_constructor_loads_legacy_bootstrap_kwargs(self): - pool = VariablePool( + def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self): + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="test_user_id"), environment_variables=[StringVariable(name="env_var", value="env-value")], conversation_variables=[StringVariable(name="conv_var", value="conv-value")], diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index feafada59a..3d97be51ab 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -2845,7 +2845,7 @@ class TestWorkflowServiceFreeNodeExecution: mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data) mock_node_cls.assert_called_once_with( node_id="n-1", - config=sentinel.node_data, + data=sentinel.node_data, graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value, graph_runtime_state=ANY, runtime=mock_runtime_cls.return_value,