mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
vibe: implement file structured output
This commit is contained in:
parent
b6465327c1
commit
b66db183c9
@ -1,188 +1,190 @@
|
||||
"""
|
||||
File reference detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
|
||||
2. Convert file ID strings to File objects after LLM returns
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.file import File
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
||||
FILE_REF_FORMAT = "dify-file-ref"
|
||||
FILE_PATH_SCHEMA_TYPE = "file"
|
||||
FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
|
||||
FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
|
||||
|
||||
|
||||
def is_file_ref_property(schema: dict) -> bool:
|
||||
"""Check if a schema property is a file reference."""
|
||||
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
|
||||
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
    """Report whether a JSON Schema node is declared as a file path.

    A node qualifies when its ``type`` equals ``FILE_PATH_SCHEMA_TYPE``, or
    when its ``format`` is a string that, after lower-casing and replacing
    underscores with dashes, is a member of ``FILE_PATH_SCHEMA_FORMATS``.
    The ``format`` check does not look at ``type`` at all.

    Args:
        schema: The schema node to inspect.

    Returns:
        True if the node is a file-path property, False otherwise.
    """
    if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
        return True

    declared_format = schema.get("format")
    if isinstance(declared_format, str):
        # Treat "File_Ref", "file_ref" and "file-ref" as the same spelling.
        canonical = declared_format.lower().replace("_", "-")
        return canonical in FILE_PATH_SCHEMA_FORMATS
    return False
|
||||
|
||||
|
||||
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""
|
||||
Recursively detect file reference fields in schema.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema to analyze
|
||||
path: Current path in the schema (used for recursion)
|
||||
|
||||
Returns:
|
||||
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
|
||||
"""
|
||||
file_ref_paths: list[str] = []
|
||||
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
file_path_fields: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, Mapping):
|
||||
continue
|
||||
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_ref_property(prop_schema):
|
||||
file_ref_paths.append(current_path)
|
||||
elif isinstance(prop_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
|
||||
if is_file_path_property(prop_schema_mapping):
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items", {})
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, Mapping):
|
||||
return file_path_fields
|
||||
items_schema_mapping = cast(Mapping[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_ref_property(items_schema):
|
||||
file_ref_paths.append(array_path)
|
||||
elif isinstance(items_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
|
||||
|
||||
return file_ref_paths
|
||||
|
||||
|
||||
def convert_file_refs_in_output(
|
||||
output: Mapping[str, Any],
|
||||
json_schema: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert file ID strings to File objects based on schema.
|
||||
|
||||
Args:
|
||||
output: The structured_output from LLM result
|
||||
json_schema: The original JSON schema (to detect file ref fields)
|
||||
tenant_id: Tenant ID for file lookup
|
||||
|
||||
Returns:
|
||||
Output with file references converted to File objects
|
||||
"""
|
||||
file_ref_paths = detect_file_ref_fields(json_schema)
|
||||
if not file_ref_paths:
|
||||
return dict(output)
|
||||
|
||||
result = _deep_copy_dict(output)
|
||||
|
||||
for path in file_ref_paths:
|
||||
_convert_path_in_place(result, path.split("."), tenant_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Deep copy a mapping to a mutable dict."""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, Mapping):
|
||||
result[key] = _deep_copy_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
|
||||
if is_file_path_property(items_schema_mapping):
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
|
||||
|
||||
return file_path_fields
|
||||
|
||||
|
||||
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
|
||||
"""Convert file refs at the given path in place, wrapping in Segment types."""
|
||||
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Return a rewritten copy of ``schema`` plus the paths of its file fields.

    The input mapping is copied first, so the caller's schema is not mutated.
    The copy is then walked by ``_adapt_schema_in_place``, which rewrites each
    file-path property and records where it was found.

    Args:
        schema: The structured-output JSON Schema.

    Returns:
        A ``(adapted_schema, file_path_fields)`` tuple, where the second item
        lists field paths such as ``"report"`` or ``"attachments[*]"``.

    Raises:
        ValueError: If the copied schema is not a dict.
    """
    adapted = _deep_copy_value(schema)
    if not isinstance(adapted, dict):
        raise ValueError("structured_output_schema must be a JSON object")
    adapted_schema = cast(dict[str, Any], adapted)

    discovered_paths: list[str] = []
    _adapt_schema_in_place(adapted_schema, path="", file_path_fields=discovered_paths)
    return adapted_schema, discovered_paths
|
||||
|
||||
|
||||
def convert_sandbox_file_paths_in_output(
    output: Mapping[str, Any],
    file_path_fields: Sequence[str],
    file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
    """Replace file-path strings in a structured output with resolved files.

    Args:
        output: The structured output produced by the model.
        file_path_fields: Field paths to convert, e.g. ``"report"`` or
            ``"items[*].attachment"``. Each path is split on ``"."``, so a
            property name that itself contains a dot cannot be addressed.
        file_resolver: Callback that turns a path string into a ``File``.

    Returns:
        A ``(converted_output, files)`` tuple. ``files`` is the list handed to
        ``_convert_path_in_place`` for every field path.

    Raises:
        ValueError: If the copied output is not a dict.
    """
    # Nothing to convert: return a shallow copy and no files.
    if not file_path_fields:
        return dict(output), []

    converted = _deep_copy_value(output)
    if not isinstance(converted, dict):
        raise ValueError("Structured output must be a JSON object")
    converted_output = cast(dict[str, Any], converted)

    resolved_files: list[File] = []
    for field_path in file_path_fields:
        _convert_path_in_place(converted_output, field_path.split("."), file_resolver, resolved_files)

    return converted_output, resolved_files
|
||||
|
||||
|
||||
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
    """Rewrite file-path properties inside ``schema`` and record their paths.

    Only ``"object"`` nodes (via ``properties``) and ``"array"`` nodes (via
    ``items``) are descended into; any other node type is left untouched.
    Child schemas that are not dicts are skipped.

    Args:
        schema: The schema node to mutate.
        path: Path of ``schema`` relative to the root; empty for the root.
        file_path_fields: Accumulator that receives the path of every
            file-path property found, in traversal order.
    """
    node_type = schema.get("type")

    # Collect (path, schema) pairs for the children of this node.
    if node_type == "object":
        declared_properties = schema.get("properties")
        if not isinstance(declared_properties, Mapping):
            return
        children = [
            (f"{path}.{child_name}" if path else child_name, cast(dict[str, Any], child_schema))
            for child_name, child_schema in cast(Mapping[str, Any], declared_properties).items()
            if isinstance(child_schema, dict)
        ]
    elif node_type == "array":
        element_schema = schema.get("items")
        if not isinstance(element_schema, dict):
            return
        # Array elements are addressed with a "[*]" suffix on the parent path.
        children = [(f"{path}[*]" if path else "[*]", cast(dict[str, Any], element_schema))]
    else:
        return

    for child_path, child in children:
        if is_file_path_property(child):
            _normalize_file_path_schema(child)
            file_path_fields.append(child_path)
        else:
            _adapt_schema_in_place(child, child_path, file_path_fields)
|
||||
|
||||
|
||||
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
    """Turn a file-path schema node into a plain string schema, in place.

    Sets ``type`` to ``"string"``, drops any ``format`` key, and makes sure the
    description ends with ``FILE_PATH_DESCRIPTION_SUFFIX``. The suffix goes on
    its own line after an existing description, or stands alone when the
    description is missing or empty.
    """
    schema["type"] = "string"
    schema.pop("format", None)

    existing = schema.get("description", "")
    schema["description"] = (
        f"{existing}\n{FILE_PATH_DESCRIPTION_SUFFIX}" if existing else FILE_PATH_DESCRIPTION_SUFFIX
    )
|
||||
|
||||
|
||||
def _deep_copy_value(value: Any) -> Any:
|
||||
if isinstance(value, Mapping):
|
||||
mapping = cast(Mapping[str, Any], value)
|
||||
return {key: _deep_copy_value(item) for key, item in mapping.items()}
|
||||
if isinstance(value, list):
|
||||
list_value = cast(list[Any], value)
|
||||
return [_deep_copy_value(item) for item in list_value]
|
||||
return value
|
||||
|
||||
|
||||
def _convert_path_in_place(
|
||||
obj: dict[str, Any],
|
||||
path_parts: list[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
files: list[File],
|
||||
) -> None:
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
# Handle array notation like "files[*]"
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else None
|
||||
target = obj.get(key) if key else obj
|
||||
key = current[:-3] if current != "[*]" else ""
|
||||
target_value = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target, list):
|
||||
if isinstance(target_value, list):
|
||||
target_list = cast(list[Any], target_value)
|
||||
if remaining:
|
||||
# Nested array with remaining path - recurse into each item
|
||||
for item in target:
|
||||
for item in target_list:
|
||||
if isinstance(item, dict):
|
||||
_convert_path_in_place(item, remaining, tenant_id)
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
_convert_path_in_place(item_dict, remaining, file_resolver, files)
|
||||
else:
|
||||
# Array of file IDs - convert all and wrap in ArrayFileSegment
|
||||
files: list[File] = []
|
||||
for item in target:
|
||||
file = _convert_file_id(item, tenant_id)
|
||||
if file is not None:
|
||||
files.append(file)
|
||||
# Replace the array with ArrayFileSegment
|
||||
resolved_files: list[File] = []
|
||||
for item in target_list:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(item)
|
||||
files.append(file)
|
||||
resolved_files.append(file)
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=files)
|
||||
obj[key] = ArrayFileSegment(value=resolved_files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
# Leaf node - convert the value and wrap in FileSegment
|
||||
if current in obj:
|
||||
file = _convert_file_id(obj[current], tenant_id)
|
||||
if file is not None:
|
||||
obj[current] = FileSegment(value=file)
|
||||
else:
|
||||
obj[current] = None
|
||||
else:
|
||||
# Recurse into nested object
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, tenant_id)
|
||||
if current not in obj:
|
||||
return
|
||||
value = obj[current]
|
||||
if value is None:
|
||||
obj[current] = None
|
||||
return
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(value)
|
||||
files.append(file)
|
||||
obj[current] = FileSegment(value=file)
|
||||
return
|
||||
|
||||
|
||||
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
|
||||
"""
|
||||
Convert a file ID string to a File object.
|
||||
|
||||
Tries multiple file sources in order:
|
||||
1. ToolFile (files generated by tools/workflows)
|
||||
2. UploadFile (files uploaded by users)
|
||||
"""
|
||||
if not isinstance(file_id, str):
|
||||
return None
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Try ToolFile first (files generated by tools/workflows)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "tool_file",
|
||||
"tool_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try UploadFile (files uploaded by users)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# File not found in any source
|
||||
return None
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, file_resolver, files)
|
||||
|
||||
@ -8,7 +8,7 @@ import json_repair
|
||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.output_parser.file_ref import detect_file_path_fields
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@ -55,12 +55,11 @@ def invoke_llm_with_structured_output(
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping | None = None,
|
||||
model_parameters: Mapping[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput:
|
||||
"""
|
||||
Invoke large language model with structured output.
|
||||
@ -78,14 +77,12 @@ def invoke_llm_with_structured_output(
|
||||
:param stop: stop words
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param tenant_id: tenant ID for file reference conversion. When provided and
|
||||
json_schema contains file reference fields (format: "dify-file-ref"),
|
||||
file IDs in the output will be automatically converted to File objects.
|
||||
:return: full response or stream response chunk generator result
|
||||
:return: response with structured output
|
||||
"""
|
||||
model_parameters_with_json_schema: dict[str, Any] = {
|
||||
**(model_parameters or {}),
|
||||
}
|
||||
model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
|
||||
|
||||
if detect_file_path_fields(json_schema):
|
||||
raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
|
||||
|
||||
# Determine structured output strategy
|
||||
|
||||
@ -122,14 +119,6 @@ def invoke_llm_with_structured_output(
|
||||
# Fill missing fields with default values
|
||||
structured_output = fill_defaults_from_schema(structured_output, json_schema)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=structured_output,
|
||||
model=llm_result.model,
|
||||
@ -147,12 +136,11 @@ def invoke_llm_with_pydantic_model(
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
output_model: type[T],
|
||||
model_parameters: Mapping | None = None,
|
||||
model_parameters: Mapping[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Invoke large language model with a Pydantic output model.
|
||||
@ -160,11 +148,8 @@ def invoke_llm_with_pydantic_model(
|
||||
This helper generates a JSON schema from the Pydantic model, invokes the
|
||||
structured-output LLM path, and validates the result.
|
||||
|
||||
The stream parameter controls the underlying LLM invocation mode:
|
||||
- stream=True (default): Uses streaming LLM call, consumes the generator internally
|
||||
- stream=False: Uses non-streaming LLM call
|
||||
|
||||
In both cases, the function returns the validated Pydantic model directly.
|
||||
The helper performs a non-streaming invocation and returns the validated
|
||||
Pydantic model directly.
|
||||
"""
|
||||
json_schema = _schema_from_pydantic(output_model)
|
||||
|
||||
@ -179,7 +164,6 @@ def invoke_llm_with_pydantic_model(
|
||||
stop=stop,
|
||||
user=user,
|
||||
callbacks=callbacks,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
structured_output = result.structured_output
|
||||
@ -236,25 +220,27 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
|
||||
return _parse_structured_output(content)
|
||||
|
||||
|
||||
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
|
||||
def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
|
||||
"""Parse JSON from tool call arguments."""
|
||||
if not arguments:
|
||||
raise OutputParserError("Tool call arguments is empty")
|
||||
|
||||
try:
|
||||
parsed = json.loads(arguments)
|
||||
if not isinstance(parsed, dict):
|
||||
parsed_any = json.loads(arguments)
|
||||
if not isinstance(parsed_any, dict):
|
||||
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
|
||||
parsed = cast(dict[str, Any], parsed_any)
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
# Try to repair malformed JSON
|
||||
repaired = json_repair.loads(arguments)
|
||||
if not isinstance(repaired, dict):
|
||||
repaired_any = json_repair.loads(arguments)
|
||||
if not isinstance(repaired_any, dict):
|
||||
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
|
||||
repaired: dict[str, Any] = repaired_any
|
||||
return repaired
|
||||
|
||||
|
||||
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
||||
def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
|
||||
"""Get default empty value for a JSON schema type."""
|
||||
# Handle array of types (e.g., ["string", "null"])
|
||||
if isinstance(type_name, list):
|
||||
@ -311,7 +297,7 @@ def fill_defaults_from_schema(
|
||||
# Create empty object and recursively fill its required fields
|
||||
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
|
||||
else:
|
||||
result[prop_name] = _get_default_value_for_type(prop_type)
|
||||
result[prop_name] = get_default_value_for_type(prop_type)
|
||||
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
|
||||
# Field exists and is an object, recursively fill nested required fields
|
||||
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
|
||||
@ -322,10 +308,10 @@ def fill_defaults_from_schema(
|
||||
def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict,
|
||||
structured_output_schema: Mapping[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
rules: list[ParameterRule],
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
@ -347,7 +333,7 @@ def _handle_native_json_schema(
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict, rules: list):
|
||||
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
@ -363,7 +349,7 @@ def _set_response_format(model_parameters: dict, rules: list):
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
|
||||
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Handle structured output for models without native JSON schema support.
|
||||
@ -400,28 +386,27 @@ def _handle_prompt_based_schema(
|
||||
return updated_prompt
|
||||
|
||||
|
||||
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
|
||||
structured_output: Mapping[str, Any] = {}
|
||||
parsed: Mapping[str, Any] = {}
|
||||
def _parse_structured_output(result_text: str) -> dict[str, Any]:
|
||||
try:
|
||||
parsed = TypeAdapter(Mapping).validate_json(result_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = parsed
|
||||
parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
|
||||
return parsed
|
||||
except ValidationError:
|
||||
# if the result_text is not a valid json, try to repair it
|
||||
temp_parsed = json_repair.loads(result_text)
|
||||
temp_parsed: Any = json_repair.loads(result_text)
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed_list = cast(list[Any], temp_parsed)
|
||||
dict_items: list[dict[str, Any]] = []
|
||||
for item in temp_parsed_list:
|
||||
if isinstance(item, dict):
|
||||
dict_items.append(cast(dict[str, Any], item))
|
||||
temp_parsed = dict_items[0] if dict_items else {}
|
||||
if not isinstance(temp_parsed, dict):
|
||||
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
|
||||
if isinstance(temp_parsed, list):
|
||||
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
|
||||
else:
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
structured_output = cast(dict, temp_parsed)
|
||||
return structured_output
|
||||
raise OutputParserError(f"Failed to parse structured output: {result_text}")
|
||||
temp_parsed_dict = cast(dict[str, Any], temp_parsed)
|
||||
return temp_parsed_dict
|
||||
|
||||
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
|
||||
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare JSON schema based on model requirements.
|
||||
|
||||
@ -433,54 +418,49 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
|
||||
"""
|
||||
|
||||
# Deep copy to avoid modifying the original schema
|
||||
processed_schema = dict(deepcopy(schema))
|
||||
processed_schema = deepcopy(schema)
|
||||
processed_schema_dict = dict(processed_schema)
|
||||
|
||||
# Convert boolean types to string types (common requirement)
|
||||
convert_boolean_to_string(processed_schema)
|
||||
convert_boolean_to_string(processed_schema_dict)
|
||||
|
||||
# Apply model-specific transformations
|
||||
if SpecialModelType.GEMINI in model_schema.model:
|
||||
remove_additional_properties(processed_schema)
|
||||
return processed_schema
|
||||
elif SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema
|
||||
else:
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
remove_additional_properties(processed_schema_dict)
|
||||
return processed_schema_dict
|
||||
if SpecialModelType.OLLAMA in provider:
|
||||
return processed_schema_dict
|
||||
|
||||
# Default format with name field
|
||||
return {"schema": processed_schema_dict, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict):
|
||||
def remove_additional_properties(schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Remove additionalProperties at current level
|
||||
schema.pop("additionalProperties", None)
|
||||
|
||||
# Process nested structures recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
remove_additional_properties(value)
|
||||
remove_additional_properties(cast(dict[str, Any], value))
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
for item in cast(list[Any], value):
|
||||
if isinstance(item, dict):
|
||||
remove_additional_properties(item)
|
||||
remove_additional_properties(cast(dict[str, Any], item))
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict):
|
||||
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
:param schema: JSON schema to modify in-place
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
return
|
||||
|
||||
# Check for boolean type at current level
|
||||
if schema.get("type") == "boolean":
|
||||
schema["type"] = "string"
|
||||
@ -488,8 +468,8 @@ def convert_boolean_to_string(schema: dict):
|
||||
# Process nested dictionaries and lists recursively
|
||||
for value in schema.values():
|
||||
if isinstance(value, dict):
|
||||
convert_boolean_to_string(value)
|
||||
convert_boolean_to_string(cast(dict[str, Any], value))
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
for item in cast(list[Any], value):
|
||||
if isinstance(item, dict):
|
||||
convert_boolean_to_string(item)
|
||||
convert_boolean_to_string(cast(dict[str, Any], item))
|
||||
|
||||
@ -14,7 +14,8 @@ from core.skill.entities import ToolAccessPolicy
|
||||
from core.skill.entities.tool_dependencies import ToolDependencies
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
from core.virtual_environment.__base.exec import CommandExecutionError
|
||||
from core.virtual_environment.__base.helpers import execute, pipeline
|
||||
|
||||
from ..bash.dify_cli import DifyCliConfig
|
||||
from ..entities import DifyCli
|
||||
@ -119,21 +120,6 @@ class SandboxBashSession:
|
||||
return self._bash_tool
|
||||
|
||||
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
|
||||
"""
|
||||
Collect files from sandbox output directory and save them as ToolFiles.
|
||||
|
||||
Scans the specified output directory in sandbox, downloads each file,
|
||||
saves it as a ToolFile, and returns a list of File objects. The File
|
||||
objects will have valid tool_file_id that can be referenced by subsequent
|
||||
nodes via structured output.
|
||||
|
||||
Args:
|
||||
output_dir: Directory path in sandbox to scan for output files.
|
||||
Defaults to "output" (relative to workspace).
|
||||
|
||||
Returns:
|
||||
List of File objects representing the collected files.
|
||||
"""
|
||||
vm = self._sandbox.vm
|
||||
collected_files: list[File] = []
|
||||
|
||||
@ -144,8 +130,6 @@ class SandboxBashSession:
|
||||
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
|
||||
return collected_files
|
||||
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
for file_state in file_states:
|
||||
# Skip files that are too large
|
||||
if file_state.size > MAX_OUTPUT_FILE_SIZE:
|
||||
@ -162,47 +146,14 @@ class SandboxBashSession:
|
||||
file_content = vm.download_file(file_state.path)
|
||||
file_binary = file_content.getvalue()
|
||||
|
||||
# Determine mime type from extension
|
||||
filename = os.path.basename(file_state.path)
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
if not mime_type:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# Save as ToolFile
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=file_binary,
|
||||
mimetype=mime_type,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
# Determine file type from mime type
|
||||
file_type = _get_file_type_from_mime(mime_type)
|
||||
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
# Create File object with tool_file_id as related_id
|
||||
file_obj = File(
|
||||
id=tool_file.id, # Use tool_file_id as the File id for easy reference
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(file_binary),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
|
||||
collected_files.append(file_obj)
|
||||
|
||||
logger.info(
|
||||
"Collected sandbox output file: %s -> tool_file_id=%s",
|
||||
file_state.path,
|
||||
tool_file.id,
|
||||
file_obj.id,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
@ -216,6 +167,85 @@ class SandboxBashSession:
|
||||
)
|
||||
return collected_files
|
||||
|
||||
def download_file(self, path: str) -> File:
    """Fetch one file from the sandbox and store it as a tool file.

    Args:
        path: Path of the file inside the sandbox.

    Returns:
        The ``File`` produced by ``_create_tool_file`` for the downloaded bytes.

    Raises:
        ValueError: If ``path`` is a directory, is not reported as a regular
            file, or its content is larger than ``MAX_OUTPUT_FILE_SIZE``.
    """
    kind = self._detect_path_kind(path)
    if kind == "dir":
        raise ValueError("Directory outputs are not supported")
    if kind != "file":
        raise ValueError(f"Sandbox file not found: {path}")

    # The size limit is checked on the downloaded bytes, i.e. after the transfer.
    payload = self._sandbox.vm.download_file(path).getvalue()
    if len(payload) > MAX_OUTPUT_FILE_SIZE:
        raise ValueError(f"Sandbox file exceeds size limit: {path}")

    # basename() is empty for paths ending in a separator; fall back to "file".
    stored_name = os.path.basename(path) or "file"
    return self._create_tool_file(filename=stored_name, file_binary=payload)
|
||||
|
||||
def _detect_path_kind(self, path: str) -> str:
    """Ask the sandbox whether ``path`` is a directory, a file, or neither.

    A small Python probe is run inside the sandbox through ``sh -c``. The shell
    line picks ``python3`` when available and ``python`` otherwise, then passes
    the probe source as ``$0`` and ``path`` as its first argument.

    Returns:
        The probe's stripped stdout: ``"dir"`` or ``"file"`` on success. The
        probe prints ``"none"`` and exits with status 2 for anything else.
        NOTE(review): whether ``execute`` raises on a non-zero exit status is
        not visible here; if it does, the "none" case surfaces as the
        ``ValueError`` below rather than as a return value.

    Raises:
        ValueError: If the command fails with ``CommandExecutionError``.
    """
    probe_source = r"""
import os
import sys

p = sys.argv[1]
if os.path.isdir(p):
    print("dir")
    raise SystemExit(0)
if os.path.isfile(p):
    print("file")
    raise SystemExit(0)
print("none")
raise SystemExit(2)
"""
    launcher = 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"'
    try:
        outcome = execute(
            self._sandbox.vm,
            ["sh", "-c", launcher, probe_source, path],
            timeout=10,
            error_message="Failed to inspect sandbox path",
        )
    except CommandExecutionError as exc:
        raise ValueError(str(exc)) from exc
    return outcome.stdout.decode("utf-8", errors="replace").strip()
|
||||
|
||||
def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
    """Store raw bytes as a tool file and describe the result as a ``File``.

    Args:
        filename: Name to store the file under. It is also used to guess the
            MIME type and to derive the extension.
        file_binary: The file content.

    Returns:
        A ``File`` with ``transfer_method`` set to ``TOOL_FILE``. Its ``id`` and
        ``related_id`` are both the stored tool file's id, and its ``url`` comes
        from ``sign_tool_file``.
    """
    mime_type, _ = mimetypes.guess_type(filename)
    if not mime_type:
        mime_type = "application/octet-stream"

    tool_file = ToolFileManager().create_file_by_raw(
        user_id=self._user_id,
        tenant_id=self._tenant_id,
        conversation_id=None,
        file_binary=file_binary,
        mimetype=mime_type,
        filename=filename,
    )

    file_type = _get_file_type_from_mime(mime_type)
    # splitext() keeps a leading dot in the root (".bashrc" -> "") and returns
    # a bare "." for a trailing dot ("name." -> "."). Testing for "." in the
    # name therefore let empty or dot-only extensions through; fall back to
    # ".bin" whenever there is no real suffix.
    suffix = os.path.splitext(filename)[1]
    extension = suffix if len(suffix) > 1 else ".bin"
    url = sign_tool_file(tool_file.id, extension)

    return File(
        id=tool_file.id,
        tenant_id=self._tenant_id,
        type=file_type,
        transfer_method=FileTransferMethod.TOOL_FILE,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=len(file_binary),
        related_id=tool_file.id,
        url=url,
        storage_key=tool_file.file_key,
    )
|
||||
|
||||
|
||||
def _get_file_type_from_mime(mime_type: str) -> FileType:
|
||||
"""Determine FileType from mime type."""
|
||||
|
||||
@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
|
||||
"""
|
||||
Build a text description of generated files for inclusion in context.
|
||||
|
||||
The description includes file_id which can be used by subsequent nodes
|
||||
to reference the files via structured output.
|
||||
The description includes file_id for context; structured output file paths
|
||||
are only supported in sandbox mode.
|
||||
"""
|
||||
if not files:
|
||||
return ""
|
||||
|
||||
@ -21,7 +21,11 @@ from core.app_assets.constants import AppAssetsAttrs
|
||||
from core.file import FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.output_parser.file_ref import (
|
||||
adapt_schema_for_sandbox_file_paths,
|
||||
convert_sandbox_file_paths_in_output,
|
||||
detect_file_path_fields,
|
||||
)
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -297,13 +301,20 @@ class LLMNode(Node[LLMNodeData]):
|
||||
)
|
||||
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
structured_output_file_paths: list[str] = []
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not self.node_data.structured_output:
|
||||
raise ValueError("structured_output_enabled is True but structured_output is not set")
|
||||
structured_output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self.node_data.structured_output
|
||||
)
|
||||
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
|
||||
if self.node_data.computer_use:
|
||||
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
|
||||
raw_schema
|
||||
)
|
||||
else:
|
||||
if detect_file_path_fields(raw_schema):
|
||||
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
|
||||
structured_output_schema = raw_schema
|
||||
else:
|
||||
structured_output_schema = None
|
||||
|
||||
@ -319,8 +330,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop=stop,
|
||||
variable_pool=variable_pool,
|
||||
tool_dependencies=tool_dependencies,
|
||||
structured_output_schema=structured_output_schema
|
||||
structured_output_schema=structured_output_schema,
|
||||
structured_output_file_paths=structured_output_file_paths,
|
||||
)
|
||||
|
||||
elif self.tool_call_enabled:
|
||||
generator = self._invoke_llm_with_tools(
|
||||
model_instance=model_instance,
|
||||
@ -330,7 +343,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
structured_output_schema=structured_output_schema
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
else:
|
||||
# Use traditional LLM invocation
|
||||
@ -532,8 +545,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_parameters=node_data_model.completion_params,
|
||||
stop=list(stop or []),
|
||||
user=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
else:
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
@ -1880,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
node_inputs: dict[str, Any],
|
||||
process_data: dict[str, Any],
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
structured_output_schema: Mapping[str, Any] | None,
|
||||
) -> Generator[NodeEventBase, None, LLMGenerationData]:
|
||||
"""Invoke LLM with tools support (from Agent V2).
|
||||
|
||||
@ -1906,14 +1919,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
files=prompt_files,
|
||||
max_iterations=self._node_data.max_iterations or 10,
|
||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||
structured_output_schema=structured_output_schema
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
# Run strategy
|
||||
outputs = strategy.run(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or [])
|
||||
stop=list(stop or []),
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
@ -1927,10 +1940,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None,
|
||||
variable_pool: VariablePool,
|
||||
tool_dependencies: ToolDependencies | None,
|
||||
structured_output_schema: Mapping[str, Any] | None
|
||||
structured_output_schema: Mapping[str, Any] | None,
|
||||
structured_output_file_paths: Sequence[str] | None,
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
|
||||
result: LLMGenerationData | None = None
|
||||
sandbox_output_files: list[File] = []
|
||||
structured_output_files: list[File] = []
|
||||
|
||||
# FIXME(Mairuis): Async processing for bash session.
|
||||
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
|
||||
@ -1948,17 +1963,31 @@ class LLMNode(Node[LLMNodeData]):
|
||||
max_iterations=self._node_data.max_iterations or 100,
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
|
||||
structured_output_schema=structured_output_schema
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
|
||||
outputs = strategy.run(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=self._node_data.model.completion_params,
|
||||
stop=list(stop or [])
|
||||
stop=list(stop or []),
|
||||
)
|
||||
|
||||
result = yield from self._process_tool_outputs(outputs)
|
||||
|
||||
if result and result.structured_output and structured_output_file_paths:
|
||||
structured_output_payload = result.structured_output.structured_output or {}
|
||||
try:
|
||||
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
|
||||
output=structured_output_payload,
|
||||
file_path_fields=structured_output_file_paths,
|
||||
file_resolver=session.download_file,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise LLMNodeError(str(exc)) from exc
|
||||
result = result.model_copy(
|
||||
update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
|
||||
)
|
||||
|
||||
# Collect output files from sandbox before session ends
|
||||
# Files are saved as ToolFiles with valid tool_file_id for later reference
|
||||
sandbox_output_files = session.collect_output_files()
|
||||
@ -1971,7 +2000,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
yield structured_output
|
||||
|
||||
# Merge sandbox output files into result
|
||||
if sandbox_output_files:
|
||||
if sandbox_output_files or structured_output_files:
|
||||
result = LLMGenerationData(
|
||||
text=result.text,
|
||||
reasoning_contents=result.reasoning_contents,
|
||||
@ -1979,7 +2008,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sequence=result.sequence,
|
||||
usage=result.usage,
|
||||
finish_reason=result.finish_reason,
|
||||
files=result.files + sandbox_output_files,
|
||||
files=result.files + sandbox_output_files + structured_output_files,
|
||||
trace=result.trace,
|
||||
)
|
||||
|
||||
@ -2056,9 +2085,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
)
|
||||
tool_instances.extend(
|
||||
build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema)
|
||||
build_agent_output_tools(
|
||||
tenant_id=self.tenant_id,
|
||||
invoke_from=self.invoke_from,
|
||||
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
|
||||
structured_output_schema=structured_output_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_instances
|
||||
@ -2564,15 +2596,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
raise ValueError("No agent result found in tool outputs")
|
||||
output_payload = agent_result.output
|
||||
if isinstance(output_payload, dict):
|
||||
state.aggregate.structured_output = LLMStructuredOutput(
|
||||
structured_output=convert_file_refs_in_output(
|
||||
output=output_payload,
|
||||
json_schema=LLMNode.fetch_structured_output_schema(
|
||||
structured_output=self._node_data.structured_output or {},
|
||||
),
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
|
||||
state.aggregate.text = json.dumps(output_payload)
|
||||
elif isinstance(output_payload, str):
|
||||
state.aggregate.text = output_payload
|
||||
|
||||
5
api/tests/fixtures/file output schema.yml
vendored
5
api/tests/fixtures/file output schema.yml
vendored
@ -126,9 +126,8 @@ workflow:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
image:
|
||||
description: File ID (UUID) of the selected image
|
||||
format: dify-file-ref
|
||||
type: string
|
||||
description: Sandbox file path of the selected image
|
||||
type: file
|
||||
required:
|
||||
- image
|
||||
type: object
|
||||
|
||||
@ -1,269 +1,93 @@
|
||||
"""
|
||||
Unit tests for file reference detection and conversion.
|
||||
Unit tests for sandbox file path detection and conversion.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.llm_generator.output_parser.file_ref import (
|
||||
FILE_REF_FORMAT,
|
||||
convert_file_refs_in_output,
|
||||
detect_file_ref_fields,
|
||||
is_file_ref_property,
|
||||
FILE_PATH_DESCRIPTION_SUFFIX,
|
||||
adapt_schema_for_sandbox_file_paths,
|
||||
convert_sandbox_file_paths_in_output,
|
||||
detect_file_path_fields,
|
||||
is_file_path_property,
|
||||
)
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestIsFileRefProperty:
|
||||
"""Tests for is_file_ref_property function."""
|
||||
|
||||
def test_valid_file_ref(self):
|
||||
schema = {"type": "string", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is True
|
||||
|
||||
def test_invalid_type(self):
|
||||
schema = {"type": "number", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_missing_format(self):
|
||||
schema = {"type": "string"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_wrong_format(self):
|
||||
schema = {"type": "string", "format": "uuid"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
def _build_file(file_id: str) -> File:
    """Return a PNG tool-file ``File`` fixture whose ``id`` and ``related_id`` are both ``file_id``."""
    fixture = File(
        # Identity: the same value is used for the file id and the related id.
        id=file_id,
        related_id=file_id,
        tenant_id="tenant_123",
        # Fixed description of a small PNG image stored as a tool file.
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.TOOL_FILE,
        filename="test.png",
        extension=".png",
        mime_type="image/png",
        size=128,
        storage_key="sandbox/path",
    )
    return fixture
|
||||
|
||||
|
||||
class TestDetectFileRefFields:
|
||||
"""Tests for detect_file_ref_fields function."""
|
||||
class TestFilePathSchema:
|
||||
def test_is_file_path_property(self):
|
||||
assert is_file_path_property({"type": "file"}) is True
|
||||
assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True
|
||||
assert is_file_path_property({"type": "string"}) is False
|
||||
|
||||
def test_simple_file_ref(self):
|
||||
def test_detect_file_path_fields(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"image": {"type": "string", "format": "dify-file-ref"},
|
||||
"files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
|
||||
"meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["image"]
|
||||
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
|
||||
|
||||
def test_multiple_file_refs(self):
|
||||
def test_adapt_schema_for_sandbox_file_paths(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"document": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"image": {"type": "string", "format": "dify-file-ref"},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "document"}
|
||||
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
|
||||
|
||||
def test_array_of_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["files[*]"]
|
||||
|
||||
def test_nested_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["data.image"]
|
||||
|
||||
def test_no_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_empty_schema(self):
|
||||
schema = {}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_mixed_schema(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "documents[*]"}
|
||||
assert set(fields) == {"image"}
|
||||
adapted_image = adapted["properties"]["image"]
|
||||
assert adapted_image["type"] == "string"
|
||||
assert "format" not in adapted_image
|
||||
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
|
||||
|
||||
|
||||
class TestConvertFileRefsInOutput:
|
||||
"""Tests for convert_file_refs_in_output function."""
|
||||
class TestConvertSandboxFilePaths:
|
||||
def test_convert_sandbox_file_paths(self):
|
||||
output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create a mock File object with all required attributes."""
|
||||
file = MagicMock(spec=File)
|
||||
file.type = FileType.IMAGE
|
||||
file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
file.related_id = "test-related-id"
|
||||
file.remote_url = None
|
||||
file.tenant_id = "tenant_123"
|
||||
file.id = None
|
||||
file.filename = "test.png"
|
||||
file.extension = ".png"
|
||||
file.mime_type = "image/png"
|
||||
file.size = 1024
|
||||
file.dify_model_identity = "__dify__file__"
|
||||
return file
|
||||
def resolver(path: str) -> File:
|
||||
return _build_file(path)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_build_from_mapping(self, mock_file):
|
||||
"""Mock the build_from_mapping function."""
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.return_value = mock_file
|
||||
yield mock
|
||||
converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
|
||||
|
||||
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
assert isinstance(converted["image"], FileSegment)
|
||||
assert isinstance(converted["files"], ArrayFileSegment)
|
||||
assert converted["name"] == "demo"
|
||||
assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
def test_invalid_path_value_raises(self):
|
||||
def resolver(path: str) -> File:
|
||||
return _build_file(path)
|
||||
|
||||
# Result should be wrapped in FileSegment
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
mock_build_from_mapping.assert_called_once_with(
|
||||
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||
tenant_id="tenant_123",
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
|
||||
|
||||
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
output = {"files": [file_id1, file_id2]}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
def test_no_file_paths_returns_output(self):
|
||||
output = {"name": "demo"}
|
||||
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in ArrayFileSegment
|
||||
assert isinstance(result["files"], ArrayFileSegment)
|
||||
assert list(result["files"].value) == [mock_file, mock_file]
|
||||
assert mock_build_from_mapping.call_count == 2
|
||||
|
||||
def test_no_conversion_without_file_refs(self):
|
||||
output = {"name": "test", "count": 5}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result == {"name": "test", "count": 5}
|
||||
|
||||
def test_invalid_uuid_returns_none(self):
|
||||
output = {"image": "not-a-valid-uuid"}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_file_not_found_returns_none(self):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.side_effect = ValueError("File not found")
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"query": "search term", "image": file_id, "count": 10}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["query"] == "search term"
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
assert result["count"] == 10
|
||||
|
||||
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
original = {"image": file_id}
|
||||
output = dict(original)
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Original should still contain the string ID
|
||||
assert original["image"] == file_id
|
||||
assert converted == output
|
||||
assert files == []
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from decimal import Decimal
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -6,16 +8,13 @@ from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
_get_default_value_for_type,
|
||||
get_default_value_for_type,
|
||||
fill_defaults_from_schema,
|
||||
invoke_llm_with_pydantic_model,
|
||||
invoke_llm_with_structured_output,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
@ -25,7 +24,30 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
|
||||
|
||||
class StructuredOutputTestCase(TypedDict):
    """Shape of one entry in the case table iterated by ``test_structured_output_parser``."""

    # Label of the scenario.
    name: str

    # Model under test; these three values are passed to ``get_model_entity``.
    provider: str
    model_name: str
    support_structure_output: bool

    # Arguments of the invocation being exercised.
    stream: bool
    json_schema: Mapping[str, Any]

    # Value installed as the mocked ``invoke_llm`` return value.
    expected_llm_response: LLMResult

    # Outcome: either an error is expected, or a result of the given type.
    should_raise: bool
    expected_result_type: type[LLMResultWithStructuredOutput] | None
    # Optional; the test falls back to ``OutputParserError`` when absent.
    expected_error: NotRequired[type[OutputParserError]]

    # Optional; when present, replaces ``parameter_rules`` on the model schema.
    parameter_rules: NotRequired[list[ParameterRule]]
|
||||
|
||||
|
||||
SchemaData = dict[str, Any]
|
||||
|
||||
|
||||
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
|
||||
@ -71,7 +93,7 @@ def get_model_instance() -> MagicMock:
|
||||
def test_structured_output_parser():
|
||||
"""Test cases for invoke_llm_with_structured_output function"""
|
||||
|
||||
testcases = [
|
||||
testcases: list[StructuredOutputTestCase] = [
|
||||
# Test case 1: Model with native structured output support, non-streaming
|
||||
{
|
||||
"name": "native_structured_output_non_streaming",
|
||||
@ -88,39 +110,6 @@ def test_structured_output_parser():
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 2: Model with native structured output support, streaming
|
||||
{
|
||||
"name": "native_structured_output_streaming",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 3: Model without native structured output support, non-streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_non_streaming",
|
||||
@ -137,78 +126,24 @@ def test_structured_output_parser():
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 4: Model without native structured output support, streaming
|
||||
{
|
||||
"name": "prompt_based_structured_output_streaming",
|
||||
"provider": "anthropic",
|
||||
"model_name": "claude-3-sonnet",
|
||||
"support_structure_output": False,
|
||||
"stream": True,
|
||||
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"answer": "test'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="claude-3-sonnet",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' response"}'),
|
||||
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 5: Streaming with list content
|
||||
{
|
||||
"name": "streaming_with_list_content",
|
||||
"name": "non_streaming_with_list_content",
|
||||
"provider": "openai",
|
||||
"model_name": "gpt-4o",
|
||||
"support_structure_output": True,
|
||||
"stream": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
|
||||
"expected_llm_response": [
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
|
||||
),
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"data":'),
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=' "value"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
|
||||
),
|
||||
),
|
||||
],
|
||||
"expected_result_type": "generator",
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
},
|
||||
# Test case 6: Error case - non-string LLM response content (non-streaming)
|
||||
@ -253,7 +188,13 @@ def test_structured_output_parser():
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["json_schema"], required=False),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(en_US="response_format"),
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
options=["json_schema"],
|
||||
),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="gpt-4o",
|
||||
@ -272,7 +213,13 @@ def test_structured_output_parser():
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
|
||||
"parameter_rules": [
|
||||
MagicMock(name="response_format", options=["JSON"], required=False),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(en_US="response_format"),
|
||||
type=ParameterType.STRING,
|
||||
required=False,
|
||||
options=["JSON"],
|
||||
),
|
||||
],
|
||||
"expected_llm_response": LLMResult(
|
||||
model="claude-3-sonnet",
|
||||
@ -285,89 +232,72 @@ def test_structured_output_parser():
|
||||
]
|
||||
|
||||
for case in testcases:
|
||||
# Setup model entity
|
||||
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
|
||||
provider = case["provider"]
|
||||
model_name = case["model_name"]
|
||||
support_structure_output = case["support_structure_output"]
|
||||
json_schema = case["json_schema"]
|
||||
stream = case["stream"]
|
||||
|
||||
# Add parameter rules if specified
|
||||
if "parameter_rules" in case:
|
||||
model_schema.parameter_rules = case["parameter_rules"]
|
||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
||||
|
||||
parameter_rules = case.get("parameter_rules")
|
||||
if parameter_rules is not None:
|
||||
model_schema.parameter_rules = parameter_rules
|
||||
|
||||
# Setup model instance
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = case["expected_llm_response"]
|
||||
|
||||
# Setup prompt messages
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Generate a response according to the schema."),
|
||||
]
|
||||
|
||||
if case["should_raise"]:
|
||||
# Test error cases
|
||||
with pytest.raises(case["expected_error"]): # noqa: PT012
|
||||
if case["stream"]:
|
||||
expected_error = case.get("expected_error", OutputParserError)
|
||||
with pytest.raises(expected_error): # noqa: PT012
|
||||
if stream:
|
||||
result_generator = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
json_schema=json_schema,
|
||||
)
|
||||
# Consume the generator to trigger the error
|
||||
list(result_generator)
|
||||
else:
|
||||
invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
json_schema=json_schema,
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=case["provider"],
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=case["json_schema"],
|
||||
json_schema=json_schema,
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 100},
|
||||
user="test_user",
|
||||
)
|
||||
|
||||
if case["expected_result_type"] == "generator":
|
||||
# Test streaming results
|
||||
assert hasattr(result, "__iter__")
|
||||
chunks = list(result)
|
||||
assert len(chunks) > 0
|
||||
expected_result_type = case["expected_result_type"]
|
||||
assert expected_result_type is not None
|
||||
assert isinstance(result, expected_result_type)
|
||||
assert result.model == model_name
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify all chunks are LLMResultChunkWithStructuredOutput
|
||||
for chunk in chunks[:-1]: # All except last
|
||||
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert chunk.model == case["model_name"]
|
||||
|
||||
# Last chunk should have structured output
|
||||
last_chunk = chunks[-1]
|
||||
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
|
||||
assert last_chunk.structured_output is not None
|
||||
assert isinstance(last_chunk.structured_output, dict)
|
||||
else:
|
||||
# Test non-streaming results
|
||||
assert isinstance(result, case["expected_result_type"])
|
||||
assert result.model == case["model_name"]
|
||||
assert result.structured_output is not None
|
||||
assert isinstance(result.structured_output, dict)
|
||||
|
||||
# Verify model_instance.invoke_llm was called with correct parameters
|
||||
model_instance.invoke_llm.assert_called_once()
|
||||
call_args = model_instance.invoke_llm.call_args
|
||||
|
||||
assert call_args.kwargs["stream"] == case["stream"]
|
||||
assert call_args.kwargs["stream"] == stream
|
||||
assert call_args.kwargs["user"] == "test_user"
|
||||
assert "temperature" in call_args.kwargs["model_parameters"]
|
||||
assert "max_tokens" in call_args.kwargs["model_parameters"]
|
||||
@ -376,45 +306,32 @@ def test_structured_output_parser():
|
||||
def test_parse_structured_output_edge_cases():
|
||||
"""Test edge cases for structured output parsing"""
|
||||
|
||||
# Test case with list that contains dict (reasoning model scenario)
|
||||
testcase_list_with_dict = {
|
||||
"name": "list_with_dict_parsing",
|
||||
"provider": "deepseek",
|
||||
"model_name": "deepseek-r1",
|
||||
"support_structure_output": False,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
|
||||
"expected_llm_response": LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
),
|
||||
"expected_result_type": LLMResultWithStructuredOutput,
|
||||
"should_raise": False,
|
||||
}
|
||||
|
||||
# Setup for list parsing test
|
||||
model_schema = get_model_entity(
|
||||
testcase_list_with_dict["provider"],
|
||||
testcase_list_with_dict["model_name"],
|
||||
testcase_list_with_dict["support_structure_output"],
|
||||
provider = "deepseek"
|
||||
model_name = "deepseek-r1"
|
||||
support_structure_output = False
|
||||
json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
|
||||
expected_llm_response = LLMResult(
|
||||
model="deepseek-r1",
|
||||
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
|
||||
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
|
||||
)
|
||||
|
||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
|
||||
model_instance.invoke_llm.return_value = expected_llm_response
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=testcase_list_with_dict["provider"],
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=testcase_list_with_dict["json_schema"],
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
@ -424,18 +341,16 @@ def test_parse_structured_output_edge_cases():
|
||||
def test_model_specific_schema_preparation():
|
||||
"""Test schema preparation for different model types"""
|
||||
|
||||
# Test Gemini model
|
||||
gemini_case = {
|
||||
"provider": "google",
|
||||
"model_name": "gemini-pro",
|
||||
"support_structure_output": True,
|
||||
"stream": False,
|
||||
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
|
||||
provider = "google"
|
||||
model_name = "gemini-pro"
|
||||
support_structure_output = True
|
||||
json_schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {"result": {"type": "boolean"}},
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
model_schema = get_model_entity(
|
||||
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
|
||||
)
|
||||
model_schema = get_model_entity(provider, model_name, support_structure_output)
|
||||
|
||||
model_instance = get_model_instance()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
@ -447,11 +362,11 @@ def test_model_specific_schema_preparation():
|
||||
prompt_messages = [UserPromptMessage(content="Test")]
|
||||
|
||||
result = invoke_llm_with_structured_output(
|
||||
provider=gemini_case["provider"],
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=gemini_case["json_schema"],
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResultWithStructuredOutput)
|
||||
@ -493,40 +408,26 @@ def test_structured_output_with_pydantic_model_non_streaming():
|
||||
assert result.name == "test"
|
||||
|
||||
|
||||
def test_structured_output_with_pydantic_model_streaming():
|
||||
def test_structured_output_with_pydantic_model_list_content():
|
||||
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
|
||||
model_instance = get_model_instance()
|
||||
|
||||
def mock_streaming_response():
|
||||
yield LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content='{"name":'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
|
||||
),
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model="gpt-4o",
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
system_fingerprint="test",
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=' "test"}'),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||
),
|
||||
)
|
||||
|
||||
model_instance.invoke_llm.return_value = mock_streaming_response()
|
||||
model_instance.invoke_llm.return_value = LLMResult(
|
||||
model="gpt-4o",
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data='{"name":'),
|
||||
TextPromptMessageContent(data=' "test"}'),
|
||||
]
|
||||
),
|
||||
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
|
||||
)
|
||||
|
||||
result = invoke_llm_with_pydantic_model(
|
||||
provider="openai",
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
|
||||
output_model=ExampleOutput
|
||||
output_model=ExampleOutput,
|
||||
)
|
||||
|
||||
assert isinstance(result, ExampleOutput)
|
||||
@ -548,51 +449,51 @@ def test_structured_output_with_pydantic_model_validation_error():
|
||||
model_schema=model_schema,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=[UserPromptMessage(content="test")],
|
||||
output_model=ExampleOutput
|
||||
output_model=ExampleOutput,
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultValueForType:
|
||||
"""Test cases for _get_default_value_for_type function"""
|
||||
"""Test cases for get_default_value_for_type function"""
|
||||
|
||||
def test_string_type(self):
|
||||
assert _get_default_value_for_type("string") == ""
|
||||
assert get_default_value_for_type("string") == ""
|
||||
|
||||
def test_object_type(self):
|
||||
assert _get_default_value_for_type("object") == {}
|
||||
assert get_default_value_for_type("object") == {}
|
||||
|
||||
def test_array_type(self):
|
||||
assert _get_default_value_for_type("array") == []
|
||||
assert get_default_value_for_type("array") == []
|
||||
|
||||
def test_number_type(self):
|
||||
assert _get_default_value_for_type("number") == 0
|
||||
assert get_default_value_for_type("number") == 0
|
||||
|
||||
def test_integer_type(self):
|
||||
assert _get_default_value_for_type("integer") == 0
|
||||
assert get_default_value_for_type("integer") == 0
|
||||
|
||||
def test_boolean_type(self):
|
||||
assert _get_default_value_for_type("boolean") is False
|
||||
assert get_default_value_for_type("boolean") is False
|
||||
|
||||
def test_null_type(self):
|
||||
assert _get_default_value_for_type("null") is None
|
||||
assert get_default_value_for_type("null") is None
|
||||
|
||||
def test_none_type(self):
|
||||
assert _get_default_value_for_type(None) is None
|
||||
assert get_default_value_for_type(None) is None
|
||||
|
||||
def test_unknown_type(self):
|
||||
assert _get_default_value_for_type("unknown") is None
|
||||
assert get_default_value_for_type("unknown") is None
|
||||
|
||||
def test_union_type_string_null(self):
|
||||
# ["string", "null"] should return "" (first non-null type)
|
||||
assert _get_default_value_for_type(["string", "null"]) == ""
|
||||
assert get_default_value_for_type(["string", "null"]) == ""
|
||||
|
||||
def test_union_type_null_first(self):
|
||||
# ["null", "integer"] should return 0 (first non-null type)
|
||||
assert _get_default_value_for_type(["null", "integer"]) == 0
|
||||
assert get_default_value_for_type(["null", "integer"]) == 0
|
||||
|
||||
def test_union_type_only_null(self):
|
||||
# ["null"] should return None
|
||||
assert _get_default_value_for_type(["null"]) is None
|
||||
assert get_default_value_for_type(["null"]) is None
|
||||
|
||||
|
||||
class TestFillDefaultsFromSchema:
|
||||
@ -600,7 +501,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_simple_required_fields(self):
|
||||
"""Test filling simple required fields"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
@ -609,7 +510,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
}
|
||||
output = {"name": "Alice"}
|
||||
output: SchemaData = {"name": "Alice"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -619,7 +520,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_non_required_fields_not_filled(self):
|
||||
"""Test that non-required fields are not filled"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
@ -627,7 +528,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["required_field"],
|
||||
}
|
||||
output = {}
|
||||
output: SchemaData = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -636,7 +537,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_nested_object_required_fields(self):
|
||||
"""Test filling nested object required fields"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
@ -659,7 +560,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["user"],
|
||||
}
|
||||
output = {
|
||||
output: SchemaData = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"address": {
|
||||
@ -684,7 +585,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_missing_nested_object_created(self):
|
||||
"""Test that missing required nested objects are created"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
@ -698,7 +599,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["metadata"],
|
||||
}
|
||||
output = {}
|
||||
output: SchemaData = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -710,7 +611,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_all_types_default_values(self):
|
||||
"""Test default values for all types"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"str_field": {"type": "string"},
|
||||
@ -722,7 +623,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
|
||||
}
|
||||
output = {}
|
||||
output: SchemaData = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -737,7 +638,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_existing_values_preserved(self):
|
||||
"""Test that existing values are not overwritten"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
@ -745,7 +646,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["name", "count"],
|
||||
}
|
||||
output = {"name": "Bob", "count": 42}
|
||||
output: SchemaData = {"name": "Bob", "count": 42}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -753,7 +654,7 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_complex_nested_structure(self):
|
||||
"""Test complex nested structure with multiple levels"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
@ -789,7 +690,7 @@ class TestFillDefaultsFromSchema:
|
||||
},
|
||||
"required": ["user", "tags", "metadata", "is_active"],
|
||||
}
|
||||
output = {
|
||||
output: SchemaData = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"age": 25,
|
||||
@ -829,8 +730,8 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_empty_schema(self):
|
||||
"""Test with empty schema"""
|
||||
schema = {}
|
||||
output = {"any": "value"}
|
||||
schema: SchemaData = {}
|
||||
output: SchemaData = {"any": "value"}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
@ -838,14 +739,14 @@ class TestFillDefaultsFromSchema:
|
||||
|
||||
def test_schema_without_required(self):
|
||||
"""Test schema without required field"""
|
||||
schema = {
|
||||
schema: SchemaData = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"optional1": {"type": "string"},
|
||||
"optional2": {"type": "integer"},
|
||||
},
|
||||
}
|
||||
output = {}
|
||||
output: SchemaData = {}
|
||||
|
||||
result = fill_defaults_from_schema(output, schema)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user