mirror of https://github.com/langgenius/dify.git
feat: start node support json schema (#29053)
This commit is contained in:
parent
79640a04cc
commit
725d6b52a7
|
|
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from jsonschema import Draft7Validator, SchemaError
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||||
|
|
@ -98,6 +99,7 @@ class VariableEntityType(StrEnum):
|
||||||
FILE = "file"
|
FILE = "file"
|
||||||
FILE_LIST = "file-list"
|
FILE_LIST = "file-list"
|
||||||
CHECKBOX = "checkbox"
|
CHECKBOX = "checkbox"
|
||||||
|
JSON_OBJECT = "json_object"
|
||||||
|
|
||||||
|
|
||||||
class VariableEntity(BaseModel):
|
class VariableEntity(BaseModel):
|
||||||
|
|
@ -118,6 +120,7 @@ class VariableEntity(BaseModel):
|
||||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||||
|
json_schema: dict[str, Any] | None = Field(default=None)
|
||||||
|
|
||||||
@field_validator("description", mode="before")
|
@field_validator("description", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -129,6 +132,17 @@ class VariableEntity(BaseModel):
|
||||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||||
return v or []
|
return v or []
|
||||||
|
|
||||||
|
@field_validator("json_schema")
|
||||||
|
@classmethod
|
||||||
|
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
|
if schema is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
Draft7Validator.check_schema(schema)
|
||||||
|
except SchemaError as e:
|
||||||
|
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
class RagPipelineVariableEntity(VariableEntity):
|
class RagPipelineVariableEntity(VariableEntity):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,8 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from jsonschema import Draft7Validator, ValidationError
|
||||||
|
|
||||||
|
from core.app.app_config.entities import VariableEntityType
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
|
|
@ -15,6 +20,7 @@ class StartNode(Node[StartNodeData]):
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
|
self._validate_and_normalize_json_object_inputs(node_inputs)
|
||||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||||
|
|
||||||
# TODO: System variables should be directly accessible, no need for special handling
|
# TODO: System variables should be directly accessible, no need for special handling
|
||||||
|
|
@ -24,3 +30,27 @@ class StartNode(Node[StartNodeData]):
|
||||||
outputs = dict(node_inputs)
|
outputs = dict(node_inputs)
|
||||||
|
|
||||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
|
||||||
|
|
||||||
|
def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None:
|
||||||
|
for variable in self.node_data.variables:
|
||||||
|
if variable.type != VariableEntityType.JSON_OBJECT:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = variable.variable
|
||||||
|
value = node_inputs.get(key)
|
||||||
|
|
||||||
|
if value is None and variable.required:
|
||||||
|
raise ValueError(f"{key} is required in input form")
|
||||||
|
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(f"{key} must be a JSON object")
|
||||||
|
|
||||||
|
schema = variable.json_schema
|
||||||
|
if not schema:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
Draft7Validator(schema).validate(value)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||||
|
node_inputs[key] = value
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,7 @@ dependencies = [
|
||||||
"weaviate-client==4.17.0",
|
"weaviate-client==4.17.0",
|
||||||
"apscheduler>=3.11.0",
|
"apscheduler>=3.11.0",
|
||||||
"weave>=0.52.16",
|
"weave>=0.52.16",
|
||||||
|
"jsonschema>=4.25.1",
|
||||||
]
|
]
|
||||||
# Before adding new dependency, consider place it in
|
# Before adding new dependency, consider place it in
|
||||||
# alphabet order (a-z) and suitable group.
|
# alphabet order (a-z) and suitable group.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,227 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError as PydanticValidationError
|
||||||
|
|
||||||
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
|
from core.workflow.nodes.start.start_node import StartNode
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
|
def make_start_node(user_inputs, variables):
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables=SystemVariable(),
|
||||||
|
user_inputs=user_inputs,
|
||||||
|
conversation_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"id": "start",
|
||||||
|
"data": StartNodeData(title="Start", variables=variables).model_dump(),
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=time.perf_counter(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return StartNode(
|
||||||
|
id="start",
|
||||||
|
config=config,
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id="tenant",
|
||||||
|
app_id="app",
|
||||||
|
workflow_id="wf",
|
||||||
|
graph_config={},
|
||||||
|
user_id="u",
|
||||||
|
user_from="account",
|
||||||
|
invoke_from="debugger",
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_valid_schema():
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"name": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["age"],
|
||||||
|
}
|
||||||
|
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
json_schema=schema,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
result = node._run()
|
||||||
|
|
||||||
|
assert result.outputs["profile"] == {"age": 20, "name": "Tom"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_invalid_json_string():
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Missing closing brace makes this invalid JSON
|
||||||
|
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||||
|
node._run()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
|
||||||
|
def test_json_object_valid_json_but_not_object(value):
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
user_inputs = {"profile": value}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||||
|
node._run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_does_not_match_schema():
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"name": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["age", "name"],
|
||||||
|
}
|
||||||
|
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
json_schema=schema,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# age is a string, which violates the schema (expects number)
|
||||||
|
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r"JSON object for 'profile' does not match schema:"):
|
||||||
|
node._run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_missing_required_schema_field():
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"age": {"type": "number"},
|
||||||
|
"name": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["age", "name"],
|
||||||
|
}
|
||||||
|
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
json_schema=schema,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Missing required field "name"
|
||||||
|
user_inputs = {"profile": {"age": 20}}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match=r"JSON object for 'profile' does not match schema: 'name' is a required property"
|
||||||
|
):
|
||||||
|
node._run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_required_variable_missing_from_inputs():
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
user_inputs = {}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="profile is required in input form"):
|
||||||
|
node._run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_invalid_json_schema_string():
|
||||||
|
variable = VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bypass pydantic type validation on assignment to simulate an invalid JSON schema string
|
||||||
|
variable.json_schema = "{invalid-json-schema"
|
||||||
|
|
||||||
|
variables = [variable]
|
||||||
|
user_inputs = {"profile": '{"age": 20}'}
|
||||||
|
|
||||||
|
# Invalid json_schema string should be rejected during node data hydration
|
||||||
|
with pytest.raises(PydanticValidationError):
|
||||||
|
make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_object_optional_variable_not_provided():
|
||||||
|
variables = [
|
||||||
|
VariableEntity(
|
||||||
|
variable="profile",
|
||||||
|
label="profile",
|
||||||
|
type=VariableEntityType.JSON_OBJECT,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
user_inputs = {}
|
||||||
|
|
||||||
|
node = make_start_node(user_inputs, variables)
|
||||||
|
|
||||||
|
# Current implementation raises a validation error even when the variable is optional
|
||||||
|
with pytest.raises(ValueError, match="profile must be a JSON object"):
|
||||||
|
node._run()
|
||||||
|
|
@ -1371,6 +1371,7 @@ dependencies = [
|
||||||
{ name = "httpx-sse" },
|
{ name = "httpx-sse" },
|
||||||
{ name = "jieba" },
|
{ name = "jieba" },
|
||||||
{ name = "json-repair" },
|
{ name = "json-repair" },
|
||||||
|
{ name = "jsonschema" },
|
||||||
{ name = "langfuse" },
|
{ name = "langfuse" },
|
||||||
{ name = "langsmith" },
|
{ name = "langsmith" },
|
||||||
{ name = "litellm" },
|
{ name = "litellm" },
|
||||||
|
|
@ -1566,6 +1567,7 @@ requires-dist = [
|
||||||
{ name = "httpx-sse", specifier = "~=0.4.0" },
|
{ name = "httpx-sse", specifier = "~=0.4.0" },
|
||||||
{ name = "jieba", specifier = "==0.42.1" },
|
{ name = "jieba", specifier = "==0.42.1" },
|
||||||
{ name = "json-repair", specifier = ">=0.41.1" },
|
{ name = "json-repair", specifier = ">=0.41.1" },
|
||||||
|
{ name = "jsonschema", specifier = ">=4.25.1" },
|
||||||
{ name = "langfuse", specifier = "~=2.51.3" },
|
{ name = "langfuse", specifier = "~=2.51.3" },
|
||||||
{ name = "langsmith", specifier = "~=0.1.77" },
|
{ name = "langsmith", specifier = "~=0.1.77" },
|
||||||
{ name = "litellm", specifier = "==1.77.1" },
|
{ name = "litellm", specifier = "==1.77.1" },
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue