refactor(workflow): add Jinja2 renderer abstraction for template transform (#30535)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2026-01-05 10:46:37 +08:00 committed by GitHub
parent 154abdd915
commit 95edbad1c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 27 deletions

View File

@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -107,6 +114,15 @@ class DifyNodeFactory(NodeFactory):
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType
@ -127,7 +127,9 @@ class TestTemplateTransformNode:
"""Test version class method."""
assert TemplateTransformNode.version() == "1"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
mock_execute.return_value = "Hello Alice, you are 30 years old!"
node = TemplateTransformNode(
id="test_node",
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
}
mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
mock_execute.return_value = "Value: "
node = TemplateTransformNode(
id="test_node",
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
mock_execute.side_effect = TemplateRenderError("Template syntax error")
node = TemplateTransformNode(
id="test_node",
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
mock_execute.return_value = "This is a very long output that exceeds the limit"
node = TemplateTransformNode(
id="test_node",
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
mock_execute.return_value = "apple, banana, orange (Total: 3)"
node = TemplateTransformNode(
id="test_node",
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
"template": "This is a static message.",
}
mock_execute.return_value = {"result": "This is a static message."}
mock_execute.return_value = "This is a static message."
node = TemplateTransformNode(
id="test_node",
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
mock_execute.return_value = "Total: $31.5"
node = TemplateTransformNode(
id="test_node",
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
node = TemplateTransformNode(
id="test_node",
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
mock_execute.return_value = "Tags: #python #ai #workflow "
node = TemplateTransformNode(
id="test_node",