mirror of https://github.com/langgenius/dify.git
refactor(workflow): add Jinja2 renderer abstraction for template transform (#30535)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
154abdd915
commit
95edbad1c7
|
|
@ -11,6 +11,11 @@ from core.workflow.graph import NodeFactory
|
|||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
|
@ -37,6 +42,7 @@ class DifyNodeFactory(NodeFactory):
|
|||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
|
@ -54,6 +60,7 @@ class DifyNodeFactory(NodeFactory):
|
|||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: dict[str, object]) -> Node:
|
||||
|
|
@ -107,6 +114,15 @@ class DifyNodeFactory(NodeFactory):
|
|||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
class TemplateRenderError(ValueError):
|
||||
"""Raised when rendering a Jinja2 template fails."""
|
||||
|
||||
|
||||
class Jinja2TemplateRenderer(Protocol):
|
||||
"""Render Jinja2 templates for template transform nodes."""
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
"""Render a Jinja2 template with provided variables."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
|
||||
"""Adapter that renders Jinja2 templates via CodeExecutor."""
|
||||
|
||||
_code_executor: type[CodeExecutor]
|
||||
|
||||
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
|
||||
self._code_executor = code_executor or CodeExecutor
|
||||
|
||||
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
|
||||
try:
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as exc:
|
||||
raise TemplateRenderError(str(exc)) from exc
|
||||
|
||||
rendered = result.get("result")
|
||||
if not isinstance(rendered, str):
|
||||
raise TemplateRenderError("Template render result must be a string.")
|
||||
return rendered
|
||||
|
|
@ -1,18 +1,44 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
TemplateRenderError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
|
@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
|||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
rendered = self._template_renderer.render_template(self.node_data.template, variables)
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
|||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
|
@ -127,7 +127,9 @@ class TestTemplateTransformNode:
|
|||
"""Test version class method."""
|
||||
assert TemplateTransformNode.version() == "1"
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_simple_template(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
|
|
@ -145,7 +147,7 @@ class TestTemplateTransformNode:
|
|||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
|
||||
# Setup mock executor
|
||||
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
|
||||
mock_execute.return_value = "Hello Alice, you are 30 years old!"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -162,7 +164,9 @@ class TestTemplateTransformNode:
|
|||
assert result.inputs["name"] == "Alice"
|
||||
assert result.inputs["age"] == 30
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with None variable values."""
|
||||
node_data = {
|
||||
|
|
@ -172,7 +176,7 @@ class TestTemplateTransformNode:
|
|||
}
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = None
|
||||
mock_execute.return_value = {"result": "Value: "}
|
||||
mock_execute.return_value = "Value: "
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -187,13 +191,15 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.inputs["value"] is None
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_code_execution_error(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run when code execution fails."""
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
|
||||
mock_execute.side_effect = CodeExecutionError("Template syntax error")
|
||||
mock_execute.side_effect = TemplateRenderError("Template syntax error")
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -208,14 +214,16 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Template syntax error" in result.error
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
|
||||
def test_run_output_length_exceeds_limit(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
"""Test _run when output exceeds maximum length."""
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
|
||||
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
|
||||
mock_execute.return_value = "This is a very long output that exceeds the limit"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -230,7 +238,9 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "Output length exceeds" in result.error
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_complex_jinja2_template(
|
||||
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
):
|
||||
|
|
@ -257,7 +267,7 @@ class TestTemplateTransformNode:
|
|||
("sys", "show_total"): mock_show_total,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
|
||||
mock_execute.return_value = "apple, banana, orange (Total: 3)"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -292,7 +302,9 @@ class TestTemplateTransformNode:
|
|||
assert mapping["node_123.var1"] == ["sys", "input1"]
|
||||
assert mapping["node_123.var2"] == ["sys", "input2"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with no variables (static template)."""
|
||||
node_data = {
|
||||
|
|
@ -301,7 +313,7 @@ class TestTemplateTransformNode:
|
|||
"template": "This is a static message.",
|
||||
}
|
||||
|
||||
mock_execute.return_value = {"result": "This is a static message."}
|
||||
mock_execute.return_value = "This is a static message."
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -317,7 +329,9 @@ class TestTemplateTransformNode:
|
|||
assert result.outputs["output"] == "This is a static message."
|
||||
assert result.inputs == {}
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with numeric variable values."""
|
||||
node_data = {
|
||||
|
|
@ -339,7 +353,7 @@ class TestTemplateTransformNode:
|
|||
("sys", "quantity"): mock_quantity,
|
||||
}
|
||||
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
|
||||
mock_execute.return_value = {"result": "Total: $31.5"}
|
||||
mock_execute.return_value = "Total: $31.5"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -354,7 +368,9 @@ class TestTemplateTransformNode:
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["output"] == "Total: $31.5"
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with dictionary variable values."""
|
||||
node_data = {
|
||||
|
|
@ -367,7 +383,7 @@ class TestTemplateTransformNode:
|
|||
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_user
|
||||
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
|
||||
mock_execute.return_value = "Name: John Doe, Email: john@example.com"
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
@ -383,7 +399,9 @@ class TestTemplateTransformNode:
|
|||
assert "John Doe" in result.outputs["output"]
|
||||
assert "john@example.com" in result.outputs["output"]
|
||||
|
||||
@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
)
|
||||
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with list variable values."""
|
||||
node_data = {
|
||||
|
|
@ -396,7 +414,7 @@ class TestTemplateTransformNode:
|
|||
mock_tags.to_object.return_value = ["python", "ai", "workflow"]
|
||||
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
|
||||
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
|
||||
mock_execute.return_value = "Tags: #python #ai #workflow "
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id="test_node",
|
||||
|
|
|
|||
Loading…
Reference in New Issue