mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 22:28:55 +08:00
Type phase 3 schema contracts
This commit is contained in:
parent
9f0d79b8b0
commit
c4aeaa35d4
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Literal, NotRequired, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -10,11 +10,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class StructuredOutputConfig(TypedDict):
|
||||
schema: Mapping[str, object]
|
||||
name: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
completion_params: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
@ -33,7 +39,7 @@ class VisionConfig(BaseModel):
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: Any):
|
||||
def convert_none_configs(cls, v: object):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
@ -44,7 +50,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
def convert_none_jinja2_variables(cls, v: object):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
@ -67,7 +73,7 @@ class LLMNodeData(BaseNodeData):
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
structured_output: StructuredOutputConfig | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
@ -90,7 +96,7 @@ class LLMNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
def convert_none_prompt_config(cls, v: object):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@ -74,6 +74,7 @@ from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
StructuredOutputConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@ -354,7 +355,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
structured_output: StructuredOutputConfig | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
@ -367,8 +368,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if structured_output_enabled:
|
||||
if structured_output is None:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=structured_output or {},
|
||||
structured_output=structured_output,
|
||||
)
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
@ -962,27 +965,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
structured_output: StructuredOutputConfig,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Fetch the structured output schema from the node data.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The structured output schema
|
||||
dict[str, object]: The structured output schema
|
||||
"""
|
||||
if not structured_output:
|
||||
schema = structured_output.get("schema")
|
||||
if not schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
||||
if not structured_output_schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
|
||||
try:
|
||||
schema = json.loads(structured_output_schema)
|
||||
if not isinstance(schema, dict):
|
||||
raise LLMNodeError("structured_output_schema must be a JSON object")
|
||||
return schema
|
||||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
return dict(schema)
|
||||
|
||||
@staticmethod
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@ -55,7 +55,7 @@ class ParameterConfig(BaseModel):
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value) -> str:
|
||||
def validate_name(cls, value: object) -> str:
|
||||
if not value:
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
@ -79,6 +79,23 @@ class ParameterConfig(BaseModel):
|
||||
return element_type
|
||||
|
||||
|
||||
class JsonSchemaArrayItems(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterJsonSchemaProperty(TypedDict, total=False):
|
||||
description: str
|
||||
type: str
|
||||
items: JsonSchemaArrayItems
|
||||
enum: list[str]
|
||||
|
||||
|
||||
class ParameterJsonSchema(TypedDict):
|
||||
type: Literal["object"]
|
||||
properties: dict[str, ParameterJsonSchemaProperty]
|
||||
required: list[str]
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
@ -95,19 +112,19 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
def set_reasoning_mode(cls, v: object) -> str:
|
||||
return str(v) if v else "function_call"
|
||||
|
||||
def get_parameter_json_schema(self):
|
||||
def get_parameter_json_schema(self) -> ParameterJsonSchema:
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
|
||||
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
@ -118,7 +135,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
parameter_schema["type"] = parameter.type.value
|
||||
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
@ -70,7 +70,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
def extract_json(text: str) -> str | None:
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
@ -392,10 +392,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
# generate tool
|
||||
parameter_schema = node_data.get_parameter_json_schema()
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
parameters={
|
||||
"type": parameter_schema["type"],
|
||||
"properties": dict(parameter_schema["properties"]),
|
||||
"required": list(parameter_schema["required"]),
|
||||
},
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
@ -602,19 +607,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result: dict[str, Any] = {}
|
||||
transformed_result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
if isinstance(param_value, (bool, int, float, str)):
|
||||
numeric_value: bool | int | float | str = param_value
|
||||
transformed = self._transform_number(numeric_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
@ -661,7 +668,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@ -672,11 +679,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return cast(dict[str, object], json.loads(json_str))
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@ -690,16 +697,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return cast(dict[str, object], json.loads(json_str))
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user