refactor: TemplateTransformNode decouple code executor (#32879)

This commit is contained in:
wangxiaolei 2026-03-03 13:36:17 +08:00 committed by GitHub
parent 4c07bc99f7
commit 1b2234a19f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 130 additions and 122 deletions

View File

@ -134,7 +134,6 @@ ignore_imports =
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager

View File

@ -119,7 +119,7 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager

View File

@ -3,7 +3,8 @@ from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from dify_graph.nodes.code.code_node import WorkflowCodeExecutor
from dify_graph.nodes.code.entities import CodeLanguage
class TemplateRenderError(ValueError):
@ -21,18 +22,18 @@ class Jinja2TemplateRenderer(Protocol):
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
_code_executor: WorkflowCodeExecutor
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def __init__(self, code_executor: WorkflowCodeExecutor) -> None:
self._code_executor = code_executor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables)
except Exception as exc:
if self._code_executor.is_execution_error(exc):
raise TemplateRenderError(str(exc)) from exc
raise
rendered = result.get("result")
if not isinstance(rendered, str):

View File

@ -6,7 +6,6 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData
from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
@ -30,7 +29,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
template_renderer: Jinja2TemplateRenderer,
max_output_length: int | None = None,
) -> None:
super().__init__(
@ -39,7 +38,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._template_renderer = template_renderer
if max_output_length is not None and max_output_length <= 0:
raise ValueError("max_output_length must be a positive integer")

View File

@ -1,22 +1,31 @@
import time
import uuid
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
class _SimpleJinja2Renderer:
"""Minimal Jinja2-based renderer for integration tests (no code executor)."""
def render_template(self, template: str, variables: dict[str, object]) -> str:
from jinja2 import Template
try:
return Template(template).render(**variables)
except Exception as exc:
raise TemplateRenderError(str(exc)) from exc
def test_execute_template_transform():
code = """{{args2}}"""
config = {
"id": "1",
@ -68,19 +77,21 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
# Create node factory (graph init path still works regardless of renderer choice below)
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert graph is not None
node = TemplateTransformNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
template_renderer=_SimpleJinja2Renderer(),
)
# execute node

View File

@ -24,6 +24,10 @@ from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.question_classifier import QuestionClassifierNode
from dify_graph.nodes.template_transform import TemplateTransformNode
from dify_graph.nodes.template_transform.template_renderer import (
Jinja2TemplateRenderer,
TemplateRenderError,
)
from dify_graph.nodes.tool import ToolNode
if TYPE_CHECKING:
@ -33,6 +37,18 @@ if TYPE_CHECKING:
from .test_mock_config import MockConfig
class _TestJinja2Renderer(Jinja2TemplateRenderer):
"""Simple Jinja2 renderer for tests (avoids code executor)."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
from jinja2 import Template as _Jinja2Template
try:
return _Jinja2Template(template).render(**variables)
except Exception as exc: # pragma: no cover - pass through as contract error
raise TemplateRenderError(str(exc)) from exc
class MockNodeMixin:
"""Mixin providing common mock functionality."""
@ -50,6 +66,10 @@ class MockNodeMixin:
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):
kwargs.setdefault("template_renderer", _TestJinja2Renderer())
super().__init__(
id=id,
config=config,

View File

@ -1,13 +1,15 @@
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from dify_graph.graph_engine.entities.graph import Graph
from dify_graph.graph_engine.entities.graph_init_params import GraphInitParams
from dify_graph.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.app.entities.app_invoke_entities import InvokeFrom
from dify_graph.entities import GraphInitParams
from dify_graph.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.runtime import GraphRuntimeState
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -24,7 +26,7 @@ class TestTemplateTransformNode:
@pytest.fixture
def mock_graph(self):
"""Create a mock Graph."""
"""Create a mock Graph (kept for backward compat in other tests)."""
return MagicMock(spec=Graph)
@pytest.fixture
@ -37,8 +39,8 @@ class TestTemplateTransformNode:
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from="test",
invoke_from="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@ -55,14 +57,15 @@ class TestTemplateTransformNode:
"template": "Hello {{ name }}, you are {{ age }} years old!",
}
def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
mock_renderer = MagicMock()
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
assert node.node_type == NodeType.TEMPLATE_TRANSFORM
@ -70,31 +73,33 @@ class TestTemplateTransformNode:
assert len(node._node_data.variables) == 2
assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!"
def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
mock_renderer = MagicMock()
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
assert node._get_title() == "Template Transform"
def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
mock_renderer = MagicMock()
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
assert node._get_description() == "Transform data using template"
def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_get_error_strategy(self, mock_graph_runtime_state, graph_init_params):
"""Test _get_error_strategy method."""
node_data = {
"title": "Test",
@ -103,12 +108,13 @@ class TestTemplateTransformNode:
"error_strategy": "fail-branch",
}
mock_renderer = MagicMock()
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH
@ -127,14 +133,8 @@ class TestTemplateTransformNode:
"""Test version class method."""
assert TemplateTransformNode.version() == "1"
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run with simple template transformation."""
def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _run with simple template transformation using injected renderer."""
# Setup mock variable pool
mock_name_value = MagicMock()
mock_name_value.to_object.return_value = "Alice"
@ -147,15 +147,16 @@ class TestTemplateTransformNode:
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
# Setup mock executor
mock_execute.return_value = "Hello Alice, you are 30 years old!"
# Setup mock renderer
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -165,11 +166,7 @@ class TestTemplateTransformNode:
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_run_with_none_values(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
"title": "Test",
@ -178,14 +175,16 @@ class TestTemplateTransformNode:
}
mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = "Value: "
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Value: "
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -193,23 +192,19 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
def test_run_with_render_error(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _run when template rendering fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = TemplateRenderError("Template syntax error")
mock_renderer = MagicMock()
mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -217,23 +212,19 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
def test_run_output_length_exceeds_limit(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = "This is a very long output that exceeds the limit"
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"
node = TemplateTransformNode(
id="test_node",
config=basic_node_data,
config={"id": "test_node", "data": basic_node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
max_output_length=10,
)
@ -242,13 +233,7 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with complex Jinja2 template including loops and conditions."""
node_data = {
"title": "Complex Template",
@ -272,14 +257,16 @@ class TestTemplateTransformNode:
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = "apple, banana, orange (Total: 3)"
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -307,11 +294,7 @@ class TestTemplateTransformNode:
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
"title": "Static Template",
@ -319,14 +302,15 @@ class TestTemplateTransformNode:
"template": "This is a static message.",
}
mock_execute.return_value = "This is a static message."
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a static message."
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -335,11 +319,7 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_run_with_numeric_values(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
"title": "Numeric Template",
@ -360,14 +340,16 @@ class TestTemplateTransformNode:
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = "Total: $31.5"
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Total: $31.5"
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -375,11 +357,7 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_run_with_dict_values(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
"title": "Dict Template",
@ -391,14 +369,16 @@ class TestTemplateTransformNode:
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()
@ -407,11 +387,7 @@ class TestTemplateTransformNode:
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]
@patch(
"dify_graph.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
def test_run_with_list_values(self, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
"title": "List Template",
@ -423,14 +399,16 @@ class TestTemplateTransformNode:
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = "Tags: #python #ai #workflow "
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "
node = TemplateTransformNode(
id="test_node",
config=node_data,
config={"id": "test_node", "data": node_data},
graph_init_params=graph_init_params,
graph=mock_graph,
graph_runtime_state=mock_graph_runtime_state,
template_renderer=mock_renderer,
)
result = node._run()