From 725d6b52a75721ac0aef85c68c34dcf7798521f5 Mon Sep 17 00:00:00 2001
From: wangxiaolei
Date: Fri, 5 Dec 2025 00:22:10 +0800
Subject: [PATCH] feat: start node support json schema (#29053)

---
Note: the added lines in entities.py, start_node.py and the new test file were
edited by hand after `git format-patch`, so the post-image blob ids on the
`index` lines of those files are stale.

 api/core/app/app_config/entities.py           |  15 ++
 api/core/workflow/nodes/start/start_node.py   |  41 ++++
 api/pyproject.toml                            |   1 +
 .../nodes/test_start_node_json_object.py      | 227 ++++++++++++++++++
 api/uv.lock                                   |   2 +
 5 files changed, 286 insertions(+)
 create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py

diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 2aa36ddc49..93f2742599 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -2,6 +2,7 @@ from collections.abc import Sequence
 from enum import StrEnum, auto
 from typing import Any, Literal
 
+from jsonschema import Draft7Validator, SchemaError
 from pydantic import BaseModel, Field, field_validator
 
 from core.file import FileTransferMethod, FileType, FileUploadConfig
@@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
     FILE = "file"
     FILE_LIST = "file-list"
     CHECKBOX = "checkbox"
+    JSON_OBJECT = "json_object"
 
 
 class VariableEntity(BaseModel):
@@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
     allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
     allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
     allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
+    json_schema: dict[str, Any] | None = Field(default=None)
 
     @field_validator("description", mode="before")
     @classmethod
@@ -129,6 +132,18 @@ class VariableEntity(BaseModel):
     def convert_none_options(cls, v: Any) -> Sequence[str]:
         return v or []
 
+    @field_validator("json_schema")
+    @classmethod
+    def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
+        """Reject a ``json_schema`` that is not itself a valid JSON Schema Draft 7 document."""
+        if schema is None:
+            return None
+        try:
+            Draft7Validator.check_schema(schema)
+        except SchemaError as e:
+            raise ValueError(f"Invalid JSON schema: {e.message}") from e
+        return schema
+
 
 class RagPipelineVariableEntity(VariableEntity):
     """
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 6d2938771f..38effa79f7 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,3 +1,8 @@
+from typing import Any
+
+from jsonschema import Draft7Validator, ValidationError
+
+from core.app.app_config.entities import VariableEntityType
 from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
 from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult
@@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]):
 
     def _run(self) -> NodeRunResult:
         node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
+        self._validate_and_normalize_json_object_inputs(node_inputs)
         system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
 
         # TODO: System variables should be directly accessible, no need for special handling
@@ -24,3 +30,38 @@ class StartNode(Node[StartNodeData]):
         outputs = dict(node_inputs)
 
         return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
+
+    def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None:
+        """
+        Validate the ``json_object`` start variables found in ``node_inputs``.
+
+        A required variable must be present, a provided value must be a dict, and a value
+        whose variable declares a ``json_schema`` must satisfy it (JSON Schema Draft 7).
+        An optional variable that was not provided is skipped.
+
+        :raises ValueError: if any of these checks fails
+        """
+        for variable in self.node_data.variables:
+            if variable.type != VariableEntityType.JSON_OBJECT:
+                continue
+
+            key = variable.variable
+            value = node_inputs.get(key)
+
+            if value is None:
+                if variable.required:
+                    raise ValueError(f"{key} is required in input form")
+                # Optional and not provided: there is nothing to validate.
+                continue
+
+            if not isinstance(value, dict):
+                raise ValueError(f"{key} must be a JSON object")
+
+            schema = variable.json_schema
+            if not schema:
+                continue
+
+            try:
+                Draft7Validator(schema).validate(value)
+            except ValidationError as e:
+                raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") from e
diff --git a/api/pyproject.toml b/api/pyproject.toml
index d28ba91413..15f7798f99 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -91,6 +91,7 @@ dependencies = [
     "weaviate-client==4.17.0",
     "apscheduler>=3.11.0",
     "weave>=0.52.16",
+    "jsonschema>=4.25.1",
 ]
 # Before adding new dependency, consider place it in
 # alphabet order (a-z) and suitable group.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
new file mode 100644
index 0000000000..83799c9508
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -0,0 +1,227 @@
+import time
+
+import pytest
+from pydantic import ValidationError as PydanticValidationError
+
+from core.app.app_config.entities import VariableEntity, VariableEntityType
+from core.workflow.entities import GraphInitParams
+from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.nodes.start.start_node import StartNode
+from core.workflow.runtime import GraphRuntimeState, VariablePool
+from core.workflow.system_variable import SystemVariable
+
+
+def make_start_node(user_inputs, variables):
+    variable_pool = VariablePool(
+        system_variables=SystemVariable(),
+        user_inputs=user_inputs,
+        conversation_variables=[],
+    )
+
+    config = {
+        "id": "start",
+        "data": StartNodeData(title="Start", variables=variables).model_dump(),
+    }
+
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=variable_pool,
+        start_at=time.perf_counter(),
+    )
+
+    return StartNode(
+        id="start",
+        config=config,
+        graph_init_params=GraphInitParams(
+            tenant_id="tenant",
+            app_id="app",
+            workflow_id="wf",
+            graph_config={},
+            user_id="u",
+            user_from="account",
+            invoke_from="debugger",
+            call_depth=0,
+        ),
+        graph_runtime_state=graph_runtime_state,
+    )
+
+
+def test_json_object_valid_schema():
+    schema = {
+        "type": "object",
+        "properties": {
+            "age": {"type": "number"},
+            "name": {"type": "string"},
+        },
+        "required": ["age"],
+    }
+
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+            json_schema=schema,
+        )
+    ]
+
+    user_inputs = {"profile": {"age": 20, "name": "Tom"}}
+
+    node = make_start_node(user_inputs, variables)
+    result = node._run()
+
+    assert result.outputs["profile"] == {"age": 20, "name": "Tom"}
+
+
+def test_json_object_invalid_json_string():
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+        )
+    ]
+
+    # Missing closing brace makes this invalid JSON
+    user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
+
+    node = make_start_node(user_inputs, variables)
+
+    with pytest.raises(ValueError, match="profile must be a JSON object"):
+        node._run()
+
+
+@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
+def test_json_object_valid_json_but_not_object(value):
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+        )
+    ]
+
+    user_inputs = {"profile": value}
+
+    node = make_start_node(user_inputs, variables)
+
+    with pytest.raises(ValueError, match="profile must be a JSON object"):
+        node._run()
+
+
+def test_json_object_does_not_match_schema():
+    schema = {
+        "type": "object",
+        "properties": {
+            "age": {"type": "number"},
+            "name": {"type": "string"},
+        },
+        "required": ["age", "name"],
+    }
+
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+            json_schema=schema,
+        )
+    ]
+
+    # age is a string, which violates the schema (expects number)
+    user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
+
+    node = make_start_node(user_inputs, variables)
+
+    with pytest.raises(ValueError, match=r"JSON object for 'profile' does not match schema:"):
+        node._run()
+
+
+def test_json_object_missing_required_schema_field():
+    schema = {
+        "type": "object",
+        "properties": {
+            "age": {"type": "number"},
+            "name": {"type": "string"},
+        },
+        "required": ["age", "name"],
+    }
+
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+            json_schema=schema,
+        )
+    ]
+
+    # Missing required field "name"
+    user_inputs = {"profile": {"age": 20}}
+
+    node = make_start_node(user_inputs, variables)
+
+    with pytest.raises(
+        ValueError, match=r"JSON object for 'profile' does not match schema: 'name' is a required property"
+    ):
+        node._run()
+
+
+def test_json_object_required_variable_missing_from_inputs():
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=True,
+        )
+    ]
+
+    user_inputs = {}
+
+    node = make_start_node(user_inputs, variables)
+
+    with pytest.raises(ValueError, match="profile is required in input form"):
+        node._run()
+
+
+def test_json_object_invalid_json_schema_string():
+    variable = VariableEntity(
+        variable="profile",
+        label="profile",
+        type=VariableEntityType.JSON_OBJECT,
+        required=True,
+    )
+
+    # Bypass pydantic type validation on assignment to simulate an invalid JSON schema string
+    variable.json_schema = "{invalid-json-schema"
+
+    variables = [variable]
+    user_inputs = {"profile": '{"age": 20}'}
+
+    # Invalid json_schema string should be rejected during node data hydration
+    with pytest.raises(PydanticValidationError):
+        make_start_node(user_inputs, variables)
+
+
+def test_json_object_optional_variable_not_provided():
+    variables = [
+        VariableEntity(
+            variable="profile",
+            label="profile",
+            type=VariableEntityType.JSON_OBJECT,
+            required=False,
+        )
+    ]
+
+    user_inputs = {}
+
+    node = make_start_node(user_inputs, variables)
+    result = node._run()
+
+    # An optional json_object variable that was not provided is skipped rather than rejected
+    assert "profile" not in result.outputs
diff --git a/api/uv.lock b/api/uv.lock
index 13b8f5bdef..e36e3e9b5f 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1371,6 +1371,7 @@ dependencies = [
     { name = "httpx-sse" },
     { name = "jieba" },
     { name = "json-repair" },
+    { name = "jsonschema" },
     { name = "langfuse" },
     { name = "langsmith" },
     { name = "litellm" },
@@ -1566,6 +1567,7 @@ requires-dist = [
     { name = "httpx-sse", specifier = "~=0.4.0" },
     { name = "jieba", specifier = "==0.42.1" },
     { name = "json-repair", specifier = ">=0.41.1" },
+    { name = "jsonschema", specifier = ">=4.25.1" },
     { name = "langfuse", specifier = "~=2.51.3" },
     { name = "langsmith", specifier = "~=0.1.77" },
     { name = "litellm", specifier = "==1.77.1" },