Type phase 3 schema contracts

This commit is contained in:
Yanli 盐粒 2026-03-17 18:56:22 +08:00
parent 9f0d79b8b0
commit c4aeaa35d4
4 changed files with 68 additions and 44 deletions

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Literal, NotRequired, TypedDict
from pydantic import BaseModel, Field, field_validator
@ -10,11 +10,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
schema: Mapping[str, object]
name: NotRequired[str]
description: NotRequired[str]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = Field(default_factory=dict)
completion_params: dict[str, object] = Field(default_factory=dict)
class ContextConfig(BaseModel):
@ -33,7 +39,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before")
@classmethod
def convert_none_configs(cls, v: Any):
def convert_none_configs(cls, v: object):
if v is None:
return VisionConfigOptions()
return v
@ -44,7 +50,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
def convert_none_jinja2_variables(cls, v: object):
if v is None:
return []
return v
@ -67,7 +73,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
structured_output: StructuredOutputConfig | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
@ -90,7 +96,7 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
def convert_none_prompt_config(cls, v: object):
if v is None:
return PromptConfig()
return v

View File

@ -74,6 +74,7 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
StructuredOutputConfig,
)
from .exc import (
InvalidContextStructureError,
@ -354,7 +355,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output: StructuredOutputConfig | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@ -367,8 +368,10 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
structured_output=structured_output,
)
request_start_time = time.perf_counter()
@ -962,27 +965,18 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def fetch_structured_output_schema(
    *,
    structured_output: StructuredOutputConfig,
) -> dict[str, object]:
    """
    Fetch the structured output schema from the node data.

    Args:
        structured_output: Typed configuration whose "schema" key holds
            the JSON-schema mapping supplied by the node configuration.

    Returns:
        dict[str, object]: A plain-dict copy of the configured schema.

    Raises:
        LLMNodeError: If no schema (or an empty one) is configured.
    """
    schema = structured_output.get("schema")
    # A missing or empty mapping is treated the same way: the caller
    # enabled structured output but provided nothing usable.
    if not schema:
        raise LLMNodeError("Please provide a valid structured output schema")
    # Copy into a fresh mutable dict so downstream mutation cannot leak
    # back into the node's stored configuration. The previous
    # json.dumps/json.loads round-trip achieved the same isolation at a
    # much higher cost and with redundant validation.
    return dict(schema)
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal
from typing import Annotated, Literal, TypedDict
from pydantic import (
BaseModel,
@ -55,7 +55,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
def validate_name(cls, value: object) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
@ -79,6 +79,23 @@ class ParameterConfig(BaseModel):
return element_type
# Schema fragment for array-typed parameters: the element type of the array.
JsonSchemaArrayItems = TypedDict("JsonSchemaArrayItems", {"type": str})

# One property entry in the generated JSON schema. total=False makes every
# key optional, mirroring the original class-syntax declaration.
ParameterJsonSchemaProperty = TypedDict(
    "ParameterJsonSchemaProperty",
    {
        "description": str,
        "type": str,
        "items": JsonSchemaArrayItems,
        "enum": list[str],
    },
    total=False,
)

# Top-level JSON schema produced by get_parameter_json_schema: always an
# object with a property map plus the list of required property names.
ParameterJsonSchema = TypedDict(
    "ParameterJsonSchema",
    {
        "type": Literal["object"],
        "properties": dict[str, ParameterJsonSchemaProperty],
        "required": list[str],
    },
)
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
@ -95,19 +112,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def set_reasoning_mode(cls, v: object) -> str:
return str(v) if v else "function_call"
def get_parameter_json_schema(self):
def get_parameter_json_schema(self) -> ParameterJsonSchema:
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
@ -118,7 +135,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
parameter_schema["type"] = parameter.type.value
if parameter.options:
parameter_schema["enum"] = parameter.options

View File

@ -70,7 +70,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
def extract_json(text):
def extract_json(text: str) -> str | None:
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
@ -392,10 +392,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
# generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
)
return prompt_messages, [tool]
@ -602,19 +607,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
"""
Transform result into standard format.
"""
transformed_result: dict[str, Any] = {}
transformed_result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
if isinstance(param_value, (bool, int, float, str)):
numeric_value: bool | int | float | str = param_value
transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
@ -661,7 +668,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
"""
Extract complete json response.
"""
@ -672,11 +679,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return cast(dict[str, object], json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
"""
Extract json from tool call.
"""
@ -690,16 +697,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return cast(dict[str, object], json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData):
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
"""
Generate default result.
"""
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0