From c4aeaa35d4bd240d101443b43fb9468cb51b4a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?= Date: Tue, 17 Mar 2026 18:56:22 +0800 Subject: [PATCH] Type phase 3 schema contracts --- api/dify_graph/nodes/llm/entities.py | 18 ++++++---- api/dify_graph/nodes/llm/node.py | 28 +++++++--------- .../nodes/parameter_extractor/entities.py | 33 ++++++++++++++----- .../parameter_extractor_node.py | 33 +++++++++++-------- 4 files changed, 68 insertions(+), 44 deletions(-) diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 6ca01a21da..1ccfd3bbe1 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Literal, NotRequired, TypedDict from pydantic import BaseModel, Field, field_validator @@ -10,11 +10,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode from dify_graph.nodes.base.entities import VariableSelector +class StructuredOutputConfig(TypedDict): + schema: Mapping[str, object] + name: NotRequired[str] + description: NotRequired[str] + + class ModelConfig(BaseModel): provider: str name: str mode: LLMMode - completion_params: dict[str, Any] = Field(default_factory=dict) + completion_params: dict[str, object] = Field(default_factory=dict) class ContextConfig(BaseModel): @@ -33,7 +39,7 @@ class VisionConfig(BaseModel): @field_validator("configs", mode="before") @classmethod - def convert_none_configs(cls, v: Any): + def convert_none_configs(cls, v: object): if v is None: return VisionConfigOptions() return v @@ -44,7 +50,7 @@ class PromptConfig(BaseModel): @field_validator("jinja2_variables", mode="before") @classmethod - def convert_none_jinja2_variables(cls, v: Any): + def convert_none_jinja2_variables(cls, v: object): if v is None: return [] return v @@ -67,7 +73,7 @@ class LLMNodeData(BaseNodeData): memory: MemoryConfig | None = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: Mapping[str, Any] | None = None + structured_output: StructuredOutputConfig | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") reasoning_format: Literal["separated", "tagged"] = Field( @@ -90,7 +96,7 @@ class LLMNodeData(BaseNodeData): @field_validator("prompt_config", mode="before") @classmethod - def convert_none_prompt_config(cls, v: Any): + def convert_none_prompt_config(cls, v: object): if v is None: return PromptConfig() return v diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5ed90ed7e3..cb002e2f6d 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -74,6 +74,7 @@ from .entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, + StructuredOutputConfig, ) from .exc import ( InvalidContextStructureError, @@ -354,7 +355,7 @@ class LLMNode(Node[LLMNodeData]): stop: Sequence[str] | None = None, user_id: str, structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, + structured_output: StructuredOutputConfig | None = None, file_saver: LLMFileSaver, file_outputs: list[File], node_id: str, @@ -367,8 +368,10 @@ class LLMNode(Node[LLMNodeData]): model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) if structured_output_enabled: + if structured_output is None: + raise LLMNodeError("Please provide a valid structured output schema") output_schema = LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, + structured_output=structured_output, ) request_start_time = time.perf_counter() @@ -962,27 +965,18 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def fetch_structured_output_schema( *, - structured_output: Mapping[str, Any], - ) -> dict[str, Any]: + structured_output: StructuredOutputConfig, + ) -> dict[str, object]: """ Fetch the structured output schema from the node data. Returns: - dict[str, Any]: The structured output schema + dict[str, object]: The structured output schema """ - if not structured_output: + schema = structured_output.get("schema") + if not schema: raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) - if not structured_output_schema: - raise LLMNodeError("Please provide a valid structured output schema") - - try: - schema = json.loads(structured_output_schema) - if not isinstance(schema, dict): - raise LLMNodeError("structured_output_schema must be a JSON object") - return schema - except json.JSONDecodeError: - raise LLMNodeError("structured_output_schema is not valid JSON format") + return dict(schema) @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 2fb042c16c..213c109607 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal +from typing import Annotated, Literal, TypedDict from pydantic import ( BaseModel, @@ -55,7 +55,7 @@ class ParameterConfig(BaseModel): @field_validator("name", mode="before") @classmethod - def validate_name(cls, value) -> str: + def validate_name(cls, value: object) -> str: if not value: raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: @@ -79,6 +79,23 @@ class ParameterConfig(BaseModel): return element_type +class JsonSchemaArrayItems(TypedDict): + type: str + + +class ParameterJsonSchemaProperty(TypedDict, total=False): + description: str + type: str + items: JsonSchemaArrayItems + enum: list[str] + + +class ParameterJsonSchema(TypedDict): + type: Literal["object"] + properties: dict[str, ParameterJsonSchemaProperty] + required: list[str] + + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. @@ -95,19 +112,19 @@ class ParameterExtractorNodeData(BaseNodeData): @field_validator("reasoning_mode", mode="before") @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" + def set_reasoning_mode(cls, v: object) -> str: + return str(v) if v else "function_call" - def get_parameter_json_schema(self): + def get_parameter_json_schema(self) -> ParameterJsonSchema: """ Get parameter json schema. :return: parameter json schema """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} + parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description} if parameter.type == SegmentType.STRING: parameter_schema["type"] = "string" @@ -118,7 +135,7 @@ class ParameterExtractorNodeData(BaseNodeData): raise AssertionError("element type should not be None.") parameter_schema["items"] = {"type": element_type.value} else: - parameter_schema["type"] = parameter.type + parameter_schema["type"] = parameter.type.value if parameter.options: parameter_schema["enum"] = parameter.options diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..f9b3e9f545 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -70,7 +70,7 @@ if TYPE_CHECKING: from dify_graph.runtime import GraphRuntimeState -def extract_json(text): +def extract_json(text: str) -> str | None: """ From a given JSON started from '{' or '[' extract the complete JSON object. """ @@ -392,10 +392,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) # generate tool + parameter_schema = node_data.get_parameter_json_schema() tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), + parameters={ + "type": parameter_schema["type"], + "properties": dict(parameter_schema["properties"]), + "required": list(parameter_schema["required"]), + }, ) return prompt_messages, [tool] @@ -602,19 +607,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): else: return None - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): + def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]: """ Transform result into standard format. """ - transformed_result: dict[str, Any] = {} + transformed_result: dict[str, object] = {} for parameter in data.parameters: if parameter.name in result: param_value = result[parameter.name] # transform value if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed + if isinstance(param_value, (bool, int, float, str)): + numeric_value: bool | int | float | str = param_value + transformed = self._transform_number(numeric_value) + if transformed is not None: + transformed_result[parameter.name] = transformed elif parameter.type == SegmentType.BOOLEAN: if isinstance(result[parameter.name], (bool, int)): transformed_result[parameter.name] = bool(result[parameter.name]) @@ -661,7 +668,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return transformed_result - def _extract_complete_json_response(self, result: str) -> dict | None: + def _extract_complete_json_response(self, result: str) -> dict[str, object] | None: """ Extract complete json response. """ @@ -672,11 +679,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): json_str = extract_json(result[idx:]) if json_str: with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) + return cast(dict[str, object], json.loads(json_str)) logger.info("extra error: %s", result) return None - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: + def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None: """ Extract json from tool call. """ @@ -690,16 +697,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): json_str = extract_json(result[idx:]) if json_str: with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) + return cast(dict[str, object], json.loads(json_str)) logger.info("extra error: %s", result) return None - def _generate_default_result(self, data: ParameterExtractorNodeData): + def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]: """ Generate default result. """ - result: dict[str, Any] = {} + result: dict[str, object] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0