diff --git a/api/.importlinter b/api/.importlinter index 37dbfb15ec..49cf70d61a 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -142,10 +142,6 @@ ignore_imports = core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer core.workflow.nodes.tool.tool_node -> models core.workflow.nodes.agent.agent_node -> models.model - core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider - core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider - core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider - core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy core.workflow.nodes.llm.node -> core.helper.code_executor core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index d02ca1ecbe..41b8c9fd7b 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -6,8 +6,10 @@ from typing_extensions import override from configs import dify_config from core.app.llm.model_access import build_dify_model_access from core.datasource.datasource_manager import DatasourceManager -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor -from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.helper.code_executor.code_executor import ( + CodeExecutionError, + CodeExecutor, +) from core.helper.ssrf_proxy import ssrf_proxy from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelType @@ -80,7 +82,6 @@ class DifyNodeFactory(NodeFactory): self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() - self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers() self._code_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, @@ -152,7 +153,6 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, code_executor=self._code_executor, - code_providers=self._code_providers, code_limits=self._code_limits, ) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 73174ed28d..d581b3ac39 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,6 +1,5 @@ import logging from collections.abc import Mapping -from enum import StrEnum from threading import Lock from typing import Any @@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from core.workflow.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) @@ -40,12 +40,6 @@ class CodeExecutionResponse(BaseModel): data: Data -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - def _build_code_executor_client() -> httpx.Client: return httpx.Client( verify=CODE_EXECUTION_SSL_VERIFY, diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index d907ce2120..7b1cbfcfea 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,10 +1,8 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast +from textwrap import dedent +from typing import TYPE_CHECKING, Any, Protocol, cast -from core.helper.code_executor.code_node_provider import CodeNodeProvider -from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node @@ -36,12 +34,44 @@ class WorkflowCodeExecutor(Protocol): def is_execution_error(self, error: Exception) -> bool: ... +def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: + return { + "type": "code", + "config": { + "variables": [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ], + "code_language": language, + "code": code, + "outputs": {"result": {"type": "string", "children": None}}, + }, + } + + +_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { + CodeLanguage.PYTHON3: dedent( + """ + def main(arg1: str, arg2: str): + return { + "result": arg1 + arg2, + } + """ + ), + CodeLanguage.JAVASCRIPT: dedent( + """ + function main({arg1, arg2}) { + return { + result: arg1 + arg2 + } + } + """ + ), +} + + class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE - _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = ( - Python3CodeProvider, - JavascriptCodeProvider, - ) _limits: CodeNodeLimits def __init__( @@ -52,7 +82,6 @@ class CodeNode(Node[CodeNodeData]): graph_runtime_state: "GraphRuntimeState", *, code_executor: WorkflowCodeExecutor, - code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits, ) -> None: super().__init__( @@ -62,9 +91,6 @@ class CodeNode(Node[CodeNodeData]): graph_runtime_state=graph_runtime_state, ) self._code_executor: WorkflowCodeExecutor = code_executor - self._code_providers: tuple[type[CodeNodeProvider], ...] = ( - tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS - ) self._limits = code_limits @classmethod @@ -78,15 +104,10 @@ class CodeNode(Node[CodeNodeData]): if filters: code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - code_provider: type[CodeNodeProvider] = next( - provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language) - ) - - return code_provider.get_default_config() - - @classmethod - def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]: - return cls._DEFAULT_CODE_PROVIDERS + default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) + if default_code is None: + raise CodeNodeError(f"Unsupported code language: {code_language}") + return _build_default_config(language=code_language, code=default_code) @classmethod def version(cls) -> str: @@ -108,7 +129,6 @@ class CodeNode(Node[CodeNodeData]): variables[variable_name] = variable.to_object() if variable else None # Run code try: - _ = self._select_code_provider(code_language) result = self._code_executor.execute( language=code_language, code=code, @@ -130,12 +150,6 @@ class CodeNode(Node[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]: - for provider in self._code_providers: - if provider.is_accept_language(code_language): - return provider - raise CodeNodeError(f"Unsupported code language: {code_language}") - def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 9a3528866c..8b73b89e2f 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,12 +1,19 @@ +from enum import StrEnum from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from core.helper.code_executor.code_executor import CodeLanguage from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector from core.workflow.variables.types import SegmentType + +class CodeLanguage(StrEnum): + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" + + _ALLOWED_OUTPUT_FROM_CODE = frozenset( [ SegmentType.STRING, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index c4fc5ccc1f..b862cbe89e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -112,7 +112,6 @@ class MockNodeFactory(DifyNodeFactory): graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, code_executor=self._code_executor, - code_providers=self._code_providers, code_limits=self._code_limits, ) elif node_type == NodeType.HTTP_REQUEST: diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 2c2da3c4f9..00c8cb3779 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,7 +1,6 @@ from configs import dify_config -from core.helper.code_executor.code_executor import CodeLanguage from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData from core.workflow.nodes.code.exc import ( CodeNodeError, DepthLimitError, @@ -438,7 +437,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node.init_node_data(data) + node._node_data = node._hydrate_node_data(data) assert node._node_data.title == "Test Node" assert node._node_data.code_language == CodeLanguage.PYTHON3 @@ -454,7 +453,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node.init_node_data(data) + node._node_data = node._hydrate_node_data(data) assert node._node_data.code_language == CodeLanguage.JAVASCRIPT diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index cfdeb30ab8..28d59c3568 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,8 +1,7 @@ import pytest from pydantic import ValidationError -from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData from core.workflow.variables.types import SegmentType