vibe: implement file structured output

This commit is contained in:
Stream 2026-02-01 02:47:28 +08:00
parent b6465327c1
commit b66db183c9
No known key found for this signature in database
GPG Key ID: 0D403F5A24E1C78B
8 changed files with 554 additions and 794 deletions

View File

@ -1,188 +1,190 @@
"""
File reference detection and conversion for structured output.
This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""
import uuid
from collections.abc import Mapping
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping
FILE_REF_FORMAT = "dify-file-ref"
FILE_PATH_SCHEMA_TYPE = "file"
FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
def is_file_ref_property(schema: dict) -> bool:
    """Return True when a schema property denotes a file reference."""
    has_string_type = schema.get("type") == "string"
    has_ref_format = schema.get("format") == FILE_REF_FORMAT
    return has_string_type and has_ref_format
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
    """Return True when *schema* declares a sandbox file-path field.

    A field qualifies either via the dedicated type ("file") or via a
    recognised format value (matched case-insensitively, with "_"
    treated as "-").
    """
    if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
        return True
    raw_format = schema.get("format")
    if isinstance(raw_format, str):
        return raw_format.lower().replace("_", "-") in FILE_PATH_SCHEMA_FORMATS
    return False
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """
    Recursively detect sandbox file-path fields in a JSON schema.

    Args:
        schema: JSON Schema to analyze
        path: Current path in the schema (used for recursion)

    Returns:
        List of JSON paths containing file-path fields,
        e.g., ["image", "files[*]"]
    """
    file_path_fields: list[str] = []
    schema_type = schema.get("type")
    if schema_type == "object":
        properties = schema.get("properties")
        if isinstance(properties, Mapping):
            properties_mapping = cast(Mapping[str, Any], properties)
            for prop_name, prop_schema in properties_mapping.items():
                # Skip malformed property entries that are not mappings.
                if not isinstance(prop_schema, Mapping):
                    continue
                prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
                current_path = f"{path}.{prop_name}" if path else prop_name
                if is_file_path_property(prop_schema_mapping):
                    file_path_fields.append(current_path)
                else:
                    file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
    elif schema_type == "array":
        items_schema = schema.get("items")
        if not isinstance(items_schema, Mapping):
            return file_path_fields
        items_schema_mapping = cast(Mapping[str, Any], items_schema)
        array_path = f"{path}[*]" if path else "[*]"
        if is_file_path_property(items_schema_mapping):
            file_path_fields.append(array_path)
        else:
            file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
    return file_path_fields
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
"""Convert file refs at the given path in place, wrapping in Segment types."""
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Deep-copy *schema*, normalize file-path properties to plain strings,
    and return the adapted copy together with the detected field paths.

    Raises:
        ValueError: when the schema root is not a JSON object.
    """
    copied = _deep_copy_value(schema)
    if not isinstance(copied, dict):
        raise ValueError("structured_output_schema must be a JSON object")
    adapted_schema = cast(dict[str, Any], copied)
    collected_paths: list[str] = []
    _adapt_schema_in_place(adapted_schema, path="", file_path_fields=collected_paths)
    return adapted_schema, collected_paths
def convert_sandbox_file_paths_in_output(
    output: Mapping[str, Any],
    file_path_fields: Sequence[str],
    file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
    """Resolve sandbox file-path values in *output* to File objects.

    Returns the converted output dict plus every File produced by
    ``file_resolver`` along the way.

    Raises:
        ValueError: when the output root is not a JSON object.
    """
    if not file_path_fields:
        return dict(output), []
    copied_output = _deep_copy_value(output)
    if not isinstance(copied_output, dict):
        raise ValueError("Structured output must be a JSON object")
    converted = cast(dict[str, Any], copied_output)
    resolved_files: list[File] = []
    for field_path in file_path_fields:
        _convert_path_in_place(converted, field_path.split("."), file_resolver, resolved_files)
    return converted, resolved_files
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
    """Walk *schema* in place, normalizing every file-path property and
    appending its JSON path (array items recorded as "key[*]") to
    *file_path_fields*.
    """
    node_type = schema.get("type")
    if node_type == "object":
        props = schema.get("properties")
        if not isinstance(props, Mapping):
            return
        for name, sub_schema in cast(Mapping[str, Any], props).items():
            # Only mutable dict sub-schemas can be normalized in place.
            if not isinstance(sub_schema, dict):
                continue
            sub_schema_dict = cast(dict[str, Any], sub_schema)
            child_path = f"{path}.{name}" if path else name
            if is_file_path_property(sub_schema_dict):
                _normalize_file_path_schema(sub_schema_dict)
                file_path_fields.append(child_path)
            else:
                _adapt_schema_in_place(sub_schema_dict, child_path, file_path_fields)
    elif node_type == "array":
        items = schema.get("items")
        if not isinstance(items, dict):
            return
        items_dict = cast(dict[str, Any], items)
        items_path = f"{path}[*]" if path else "[*]"
        if is_file_path_property(items_dict):
            _normalize_file_path_schema(items_dict)
            file_path_fields.append(items_path)
        else:
            _adapt_schema_in_place(items_dict, items_path, file_path_fields)
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
    """Rewrite a file-path property into a plain string schema in place,
    appending the sandbox-path hint to its description.
    """
    schema["type"] = "string"
    schema.pop("format", None)
    existing_description = schema.get("description", "")
    if existing_description:
        schema["description"] = f"{existing_description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
    else:
        schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
def _deep_copy_value(value: Any) -> Any:
    """Recursively copy mappings to plain dicts and lists to new lists.

    Scalars (and any other type) are returned unchanged — assumed to be
    immutable JSON leaf values.
    """
    if isinstance(value, Mapping):
        source_mapping = cast(Mapping[str, Any], value)
        return {key: _deep_copy_value(entry) for key, entry in source_mapping.items()}
    if isinstance(value, list):
        return [_deep_copy_value(entry) for entry in cast(list[Any], value)]
    return value
def _convert_path_in_place(
    obj: dict[str, Any],
    path_parts: list[str],
    file_resolver: Callable[[str], File],
    files: list[File],
) -> None:
    """Resolve file-path values at the given schema path, in place.

    Walks ``obj`` along ``path_parts`` (supporting "key[*]" array
    notation), resolves each string value through ``file_resolver``,
    wraps the results in FileSegment / ArrayFileSegment, and appends
    every resolved File to ``files``.

    Raises:
        ValueError: when a value at a file-path location is not a string.
    """
    if not path_parts:
        return
    current = path_parts[0]
    remaining = path_parts[1:]
    # Handle array notation like "files[*]"
    if current.endswith("[*]"):
        key = current[:-3] if current != "[*]" else ""
        target_value = obj.get(key) if key else obj
        if isinstance(target_value, list):
            target_list = cast(list[Any], target_value)
            if remaining:
                # Nested array with remaining path - recurse into each item.
                for item in target_list:
                    if isinstance(item, dict):
                        item_dict = cast(dict[str, Any], item)
                        _convert_path_in_place(item_dict, remaining, file_resolver, files)
            else:
                # Array of file paths - resolve all and wrap in ArrayFileSegment.
                resolved_files: list[File] = []
                for item in target_list:
                    if not isinstance(item, str):
                        raise ValueError("File path must be a string")
                    file = file_resolver(item)
                    files.append(file)
                    resolved_files.append(file)
                if key:
                    obj[key] = ArrayFileSegment(value=resolved_files)
        return
    if not remaining:
        # Leaf node - resolve the value and wrap it in a FileSegment.
        if current not in obj:
            return
        value = obj[current]
        if value is None:
            obj[current] = None
            return
        if not isinstance(value, str):
            raise ValueError("File path must be a string")
        file = file_resolver(value)
        files.append(file)
        obj[current] = FileSegment(value=file)
        return
    # Recurse into a nested object.
    if current in obj and isinstance(obj[current], dict):
        _convert_path_in_place(obj[current], remaining, file_resolver, files)

View File

@ -8,7 +8,7 @@ import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.output_parser.file_ref import detect_file_path_fields
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@ -55,12 +55,11 @@ def invoke_llm_with_structured_output(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping | None = None,
model_parameters: Mapping[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with structured output.
@ -78,14 +77,12 @@ def invoke_llm_with_structured_output(
:param stop: stop words
:param user: unique user id
:param callbacks: callbacks
:param tenant_id: tenant ID for file reference conversion. When provided and
json_schema contains file reference fields (format: "dify-file-ref"),
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
:return: response with structured output
"""
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
}
model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
if detect_file_path_fields(json_schema):
raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
# Determine structured output strategy
@ -122,14 +119,6 @@ def invoke_llm_with_structured_output(
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
@ -147,12 +136,11 @@ def invoke_llm_with_pydantic_model(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
model_parameters: Mapping[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> T:
"""
Invoke large language model with a Pydantic output model.
@ -160,11 +148,8 @@ def invoke_llm_with_pydantic_model(
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result.
The stream parameter controls the underlying LLM invocation mode:
- stream=True (default): Uses streaming LLM call, consumes the generator internally
- stream=False: Uses non-streaming LLM call
In both cases, the function returns the validated Pydantic model directly.
The helper performs a non-streaming invocation and returns the validated
Pydantic model directly.
"""
json_schema = _schema_from_pydantic(output_model)
@ -179,7 +164,6 @@ def invoke_llm_with_pydantic_model(
stop=stop,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)
structured_output = result.structured_output
@ -236,25 +220,27 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
return _parse_structured_output(content)
def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
    """Parse JSON from tool call arguments.

    Falls back to json_repair when the arguments are not valid JSON.

    Raises:
        OutputParserError: when arguments are empty or do not decode to a dict.
    """
    if not arguments:
        raise OutputParserError("Tool call arguments is empty")
    try:
        parsed_any = json.loads(arguments)
    except json.JSONDecodeError:
        # Try to repair malformed JSON emitted by the model.
        repaired_any = json_repair.loads(arguments)
        if not isinstance(repaired_any, dict):
            raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
        repaired: dict[str, Any] = repaired_any
        return repaired
    # Keep the try body minimal: type-check outside the except scope.
    if not isinstance(parsed_any, dict):
        raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
    return cast(dict[str, Any], parsed_any)
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
"""Get default empty value for a JSON schema type."""
# Handle array of types (e.g., ["string", "null"])
if isinstance(type_name, list):
@ -311,7 +297,7 @@ def fill_defaults_from_schema(
# Create empty object and recursively fill its required fields
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
else:
result[prop_name] = _get_default_value_for_type(prop_type)
result[prop_name] = get_default_value_for_type(prop_type)
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
# Field exists and is an object, recursively fill nested required fields
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
@ -322,10 +308,10 @@ def fill_defaults_from_schema(
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict,
structured_output_schema: Mapping[str, Any],
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
) -> dict[str, Any]:
"""
Handle structured output for models with native JSON schema support.
@ -347,7 +333,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict, rules: list):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
"""
Set the appropriate response format parameter based on model rules.
@ -363,7 +349,7 @@ def _set_response_format(model_parameters: dict, rules: list):
def _handle_prompt_based_schema(
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
@ -400,28 +386,27 @@ def _handle_prompt_based_schema(
return updated_prompt
def _parse_structured_output(result_text: str) -> dict[str, Any]:
    """Parse *result_text* into a JSON object, repairing malformed JSON if needed.

    Raises:
        OutputParserError: when the text cannot be coerced into a dict.
    """
    try:
        return TypeAdapter(dict[str, Any]).validate_json(result_text)
    except ValidationError:
        # Not strictly valid JSON (e.g. reasoning models like deepseek-r1 may
        # emit a '<think>...</think>' prefix) - attempt a best-effort repair.
        repaired: Any = json_repair.loads(result_text)
        if isinstance(repaired, list):
            # Some repairs yield a list of fragments; keep the first dict.
            repaired_list = cast(list[Any], repaired)
            dict_items: list[dict[str, Any]] = []
            for item in repaired_list:
                if isinstance(item, dict):
                    dict_items.append(cast(dict[str, Any], item))
            repaired = dict_items[0] if dict_items else {}
        if not isinstance(repaired, dict):
            raise OutputParserError(f"Failed to parse structured output: {result_text}")
        return cast(dict[str, Any], repaired)
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
"""
Prepare JSON schema based on model requirements.
@ -433,54 +418,49 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
"""
# Deep copy to avoid modifying the original schema
processed_schema = dict(deepcopy(schema))
processed_schema = deepcopy(schema)
processed_schema_dict = dict(processed_schema)
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema)
convert_boolean_to_string(processed_schema_dict)
# Apply model-specific transformations
if SpecialModelType.GEMINI in model_schema.model:
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
remove_additional_properties(processed_schema_dict)
return processed_schema_dict
if SpecialModelType.OLLAMA in provider:
return processed_schema_dict
# Default format with name field
return {"schema": processed_schema_dict, "name": "llm_response"}
def remove_additional_properties(schema: dict[str, Any]) -> None:
    """
    Remove additionalProperties fields from a JSON schema, in place.

    Used for models like Gemini that don't support this property.

    :param schema: JSON schema to modify in-place
    """
    # Remove additionalProperties at the current level.
    schema.pop("additionalProperties", None)
    # Recurse into nested dict values and dict items inside lists.
    for value in schema.values():
        if isinstance(value, dict):
            remove_additional_properties(cast(dict[str, Any], value))
        elif isinstance(value, list):
            for item in cast(list[Any], value):
                if isinstance(item, dict):
                    remove_additional_properties(cast(dict[str, Any], item))
def convert_boolean_to_string(schema: dict):
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
@ -488,8 +468,8 @@ def convert_boolean_to_string(schema: dict):
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
convert_boolean_to_string(cast(dict[str, Any], value))
elif isinstance(value, list):
for item in value:
for item in cast(list[Any], value):
if isinstance(item, dict):
convert_boolean_to_string(item)
convert_boolean_to_string(cast(dict[str, Any], item))

View File

@ -14,7 +14,8 @@ from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from core.virtual_environment.__base.helpers import pipeline
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, pipeline
from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
@ -119,21 +120,6 @@ class SandboxBashSession:
return self._bash_tool
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
"""
Collect files from sandbox output directory and save them as ToolFiles.
Scans the specified output directory in sandbox, downloads each file,
saves it as a ToolFile, and returns a list of File objects. The File
objects will have valid tool_file_id that can be referenced by subsequent
nodes via structured output.
Args:
output_dir: Directory path in sandbox to scan for output files.
Defaults to "output" (relative to workspace).
Returns:
List of File objects representing the collected files.
"""
vm = self._sandbox.vm
collected_files: list[File] = []
@ -144,8 +130,6 @@ class SandboxBashSession:
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
return collected_files
tool_file_manager = ToolFileManager()
for file_state in file_states:
# Skip files that are too large
if file_state.size > MAX_OUTPUT_FILE_SIZE:
@ -162,47 +146,14 @@ class SandboxBashSession:
file_content = vm.download_file(file_state.path)
file_binary = file_content.getvalue()
# Determine mime type from extension
filename = os.path.basename(file_state.path)
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# Save as ToolFile
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)
# Determine file type from mime type
file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)
# Create File object with tool_file_id as related_id
file_obj = File(
id=tool_file.id, # Use tool_file_id as the File id for easy reference
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
collected_files.append(file_obj)
logger.info(
"Collected sandbox output file: %s -> tool_file_id=%s",
file_state.path,
tool_file.id,
file_obj.id,
)
except Exception as exc:
@ -216,6 +167,85 @@ class SandboxBashSession:
)
return collected_files
def download_file(self, path: str) -> File:
    """Download a single sandbox file at *path* and persist it as a
    ToolFile-backed File object.

    Raises:
        ValueError: for directories, missing paths, or oversized files.
    """
    kind = self._detect_path_kind(path)
    if kind == "dir":
        raise ValueError("Directory outputs are not supported")
    if kind != "file":
        raise ValueError(f"Sandbox file not found: {path}")
    content = self._sandbox.vm.download_file(path)
    raw_bytes = content.getvalue()
    # Enforce the same size cap used when collecting output files.
    if len(raw_bytes) > MAX_OUTPUT_FILE_SIZE:
        raise ValueError(f"Sandbox file exceeds size limit: {path}")
    name = os.path.basename(path) or "file"
    return self._create_tool_file(filename=name, file_binary=raw_bytes)
def _detect_path_kind(self, path: str) -> str:
    """Classify *path* inside the sandbox.

    Returns "dir", "file", or "none" (missing/unknown path), as printed
    by a small Python probe executed in the sandbox VM.
    """
    probe_script = r"""
import os
import sys
p = sys.argv[1]
if os.path.isdir(p):
    print("dir")
    raise SystemExit(0)
if os.path.isfile(p):
    print("file")
    raise SystemExit(0)
print("none")
raise SystemExit(2)
"""
    # Prefer python3 but fall back to python; in the sh -c form, "$0"
    # receives the probe script and "$@" forwards the path argument.
    command = [
        "sh",
        "-c",
        'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"',
        probe_script,
        path,
    ]
    try:
        probe_result = execute(
            self._sandbox.vm,
            command,
            timeout=10,
            error_message="Failed to inspect sandbox path",
        )
    except CommandExecutionError as exc:
        raise ValueError(str(exc)) from exc
    return probe_result.stdout.decode("utf-8", errors="replace").strip()
def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
    """Persist *file_binary* as a ToolFile and wrap it in a File object
    whose id/related_id both point at the new tool file.
    """
    guessed_type, _ = mimetypes.guess_type(filename)
    mime_type = guessed_type or "application/octet-stream"
    tool_file = ToolFileManager().create_file_by_raw(
        user_id=self._user_id,
        tenant_id=self._tenant_id,
        conversation_id=None,
        file_binary=file_binary,
        mimetype=mime_type,
        filename=filename,
    )
    extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
    return File(
        id=tool_file.id,
        tenant_id=self._tenant_id,
        type=_get_file_type_from_mime(mime_type),
        transfer_method=FileTransferMethod.TOOL_FILE,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=len(file_binary),
        related_id=tool_file.id,
        url=sign_tool_file(tool_file.id, extension),
        storage_key=tool_file.file_key,
    )
def _get_file_type_from_mime(mime_type: str) -> FileType:
"""Determine FileType from mime type."""

View File

@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
"""
Build a text description of generated files for inclusion in context.
The description includes file_id which can be used by subsequent nodes
to reference the files via structured output.
The description includes file_id for context; structured output file paths
are only supported in sandbox mode.
"""
if not files:
return ""

View File

@ -21,7 +21,11 @@ from core.app_assets.constants import AppAssetsAttrs
from core.file import FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.output_parser.file_ref import (
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
@ -297,13 +301,20 @@ class LLMNode(Node[LLMNodeData]):
)
structured_output_schema: Mapping[str, Any] | None
structured_output_file_paths: list[str] = []
if self.node_data.structured_output_enabled:
if not self.node_data.structured_output:
raise ValueError("structured_output_enabled is True but structured_output is not set")
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self.node_data.structured_output
)
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
if self.node_data.computer_use:
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
raw_schema
)
else:
if detect_file_path_fields(raw_schema):
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
structured_output_schema = raw_schema
else:
structured_output_schema = None
@ -319,8 +330,10 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
structured_output_schema=structured_output_schema
structured_output_schema=structured_output_schema,
structured_output_file_paths=structured_output_file_paths,
)
elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
model_instance=model_instance,
@ -330,7 +343,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
structured_output_schema=structured_output_schema
structured_output_schema=structured_output_schema,
)
else:
# Use traditional LLM invocation
@ -532,8 +545,8 @@ class LLMNode(Node[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
tenant_id=tenant_id,
)
else:
request_start_time = time.perf_counter()
@ -1880,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
structured_output_schema: Mapping[str, Any] | None
structured_output_schema: Mapping[str, Any] | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
@ -1906,14 +1919,14 @@ class LLMNode(Node[LLMNodeData]):
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
structured_output_schema=structured_output_schema,
)
# Run strategy
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or [])
stop=list(stop or []),
)
result = yield from self._process_tool_outputs(outputs)
@ -1927,10 +1940,12 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
structured_output_schema: Mapping[str, Any] | None
structured_output_schema: Mapping[str, Any] | None,
structured_output_file_paths: Sequence[str] | None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
result: LLMGenerationData | None = None
sandbox_output_files: list[File] = []
structured_output_files: list[File] = []
# FIXME(Mairuis): Async processing for bash session.
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
@ -1948,17 +1963,31 @@ class LLMNode(Node[LLMNodeData]):
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema
structured_output_schema=structured_output_schema,
)
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or [])
stop=list(stop or []),
)
result = yield from self._process_tool_outputs(outputs)
if result and result.structured_output and structured_output_file_paths:
structured_output_payload = result.structured_output.structured_output or {}
try:
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
output=structured_output_payload,
file_path_fields=structured_output_file_paths,
file_resolver=session.download_file,
)
except ValueError as exc:
raise LLMNodeError(str(exc)) from exc
result = result.model_copy(
update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
)
# Collect output files from sandbox before session ends
# Files are saved as ToolFiles with valid tool_file_id for later reference
sandbox_output_files = session.collect_output_files()
@ -1971,7 +2000,7 @@ class LLMNode(Node[LLMNodeData]):
yield structured_output
# Merge sandbox output files into result
if sandbox_output_files:
if sandbox_output_files or structured_output_files:
result = LLMGenerationData(
text=result.text,
reasoning_contents=result.reasoning_contents,
@ -1979,7 +2008,7 @@ class LLMNode(Node[LLMNodeData]):
sequence=result.sequence,
usage=result.usage,
finish_reason=result.finish_reason,
files=result.files + sandbox_output_files,
files=result.files + sandbox_output_files + structured_output_files,
trace=result.trace,
)
@ -2056,9 +2085,12 @@ class LLMNode(Node[LLMNodeData]):
structured_output=self._node_data.structured_output or {},
)
tool_instances.extend(
build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema)
build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema,
)
)
return tool_instances
@ -2564,15 +2596,7 @@ class LLMNode(Node[LLMNodeData]):
raise ValueError("No agent result found in tool outputs")
output_payload = agent_result.output
if isinstance(output_payload, dict):
state.aggregate.structured_output = LLMStructuredOutput(
structured_output=convert_file_refs_in_output(
output=output_payload,
json_schema=LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
),
tenant_id=self.tenant_id,
)
)
state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
state.aggregate.text = json.dumps(output_payload)
elif isinstance(output_payload, str):
state.aggregate.text = output_payload

View File

@ -126,9 +126,8 @@ workflow:
additionalProperties: false
properties:
image:
description: File ID (UUID) of the selected image
format: dify-file-ref
type: string
description: Sandbox file path of the selected image
type: file
required:
- image
type: object

View File

@ -1,269 +1,93 @@
"""
Unit tests for file reference detection and conversion.
Unit tests for sandbox file path detection and conversion.
"""
import uuid
from unittest.mock import MagicMock, patch
import pytest
from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
FILE_REF_FORMAT,
convert_file_refs_in_output,
detect_file_ref_fields,
is_file_ref_property,
FILE_PATH_DESCRIPTION_SUFFIX,
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
is_file_path_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment
class TestIsFileRefProperty:
"""Tests for is_file_ref_property function."""
def test_valid_file_ref(self):
schema = {"type": "string", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is True
def test_invalid_type(self):
schema = {"type": "number", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is False
def test_missing_format(self):
schema = {"type": "string"}
assert is_file_ref_property(schema) is False
def test_wrong_format(self):
schema = {"type": "string", "format": "uuid"}
assert is_file_ref_property(schema) is False
def _build_file(file_id: str) -> File:
    """Return a minimal File fixture whose id and related_id are both *file_id*.

    All other attributes are fixed, arbitrary values suitable for assertions
    that only care about identity and type, not content.
    """
    attrs = {
        "id": file_id,
        "related_id": file_id,
        "tenant_id": "tenant_123",
        "type": FileType.IMAGE,
        "transfer_method": FileTransferMethod.TOOL_FILE,
        "filename": "test.png",
        "extension": ".png",
        "mime_type": "image/png",
        "size": 128,
        "storage_key": "sandbox/path",
    }
    return File(**attrs)
class TestDetectFileRefFields:
"""Tests for detect_file_ref_fields function."""
class TestFilePathSchema:
def test_is_file_path_property(self):
    """A property is a file path when its type is ``file`` or it carries the legacy dify-file-ref format."""
    cases = [
        ({"type": "file"}, True),
        ({"type": "string", "format": "dify-file-ref"}, True),
        ({"type": "string"}, False),
    ]
    for schema, expected in cases:
        assert is_file_path_property(schema) is expected
def test_simple_file_ref(self):
def test_detect_file_path_fields(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
"image": {"type": "string", "format": "dify-file-ref"},
"files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
"meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["image"]
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
def test_multiple_file_refs(self):
def test_adapt_schema_for_sandbox_file_paths(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
"document": {"type": "string", "format": FILE_REF_FORMAT},
"image": {"type": "string", "format": "dify-file-ref"},
"name": {"type": "string"},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "document"}
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
def test_array_of_file_refs(self):
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["files[*]"]
def test_nested_file_ref(self):
schema = {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["data.image"]
def test_no_file_refs(self):
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_empty_schema(self):
schema = {}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_mixed_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"documents": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "documents[*]"}
assert set(fields) == {"image"}
adapted_image = adapted["properties"]["image"]
assert adapted_image["type"] == "string"
assert "format" not in adapted_image
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
class TestConvertFileRefsInOutput:
"""Tests for convert_file_refs_in_output function."""
class TestConvertSandboxFilePaths:
def test_convert_sandbox_file_paths(self):
output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
@pytest.fixture
def mock_file(self):
"""Create a mock File object with all required attributes."""
file = MagicMock(spec=File)
file.type = FileType.IMAGE
file.transfer_method = FileTransferMethod.TOOL_FILE
file.related_id = "test-related-id"
file.remote_url = None
file.tenant_id = "tenant_123"
file.id = None
file.filename = "test.png"
file.extension = ".png"
file.mime_type = "image/png"
file.size = 1024
file.dify_model_identity = "__dify__file__"
return file
def resolver(path: str) -> File:
return _build_file(path)
@pytest.fixture
def mock_build_from_mapping(self, mock_file):
"""Mock the build_from_mapping function."""
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.return_value = mock_file
yield mock
converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
assert isinstance(converted["image"], FileSegment)
assert isinstance(converted["files"], ArrayFileSegment)
assert converted["name"] == "demo"
assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
result = convert_file_refs_in_output(output, schema, "tenant_123")
def test_invalid_path_value_raises(self):
def resolver(path: str) -> File:
return _build_file(path)
# Result should be wrapped in FileSegment
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
mock_build_from_mapping.assert_called_once_with(
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
tenant_id="tenant_123",
)
with pytest.raises(ValueError):
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
output = {"files": [file_id1, file_id2]}
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
def test_no_file_paths_returns_output(self):
output = {"name": "demo"}
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in ArrayFileSegment
assert isinstance(result["files"], ArrayFileSegment)
assert list(result["files"].value) == [mock_file, mock_file]
assert mock_build_from_mapping.call_count == 2
def test_no_conversion_without_file_refs(self):
output = {"name": "test", "count": 5}
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result == {"name": "test", "count": 5}
def test_invalid_uuid_returns_none(self):
output = {"image": "not-a-valid-uuid"}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_file_not_found_returns_none(self):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.side_effect = ValueError("File not found")
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"query": "search term", "image": file_id, "count": 10}
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["query"] == "search term"
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
assert result["count"] == 10
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
original = {"image": file_id}
output = dict(original)
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
convert_file_refs_in_output(output, schema, "tenant_123")
# Original should still contain the string ID
assert original["image"] == file_id
assert converted == output
assert files == []

View File

@ -1,4 +1,6 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any, NotRequired, TypedDict
from unittest.mock import MagicMock, patch
import pytest
@ -6,16 +8,13 @@ from pydantic import BaseModel, ConfigDict
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
_get_default_value_for_type,
get_default_value_for_type,
fill_defaults_from_schema,
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
@ -25,7 +24,30 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
ParameterRule,
ParameterType,
)
class StructuredOutputTestCase(TypedDict):
    """Declarative description of one structured-output parser test case."""

    # Human-readable identifier, used to recognize the case in failure output.
    name: str
    # Model provider key, e.g. "openai" or "anthropic".
    provider: str
    # Model name passed to the entity builder, e.g. "gpt-4o".
    model_name: str
    # Whether the mocked model natively supports structured output.
    support_structure_output: bool
    # Whether the invocation is made in streaming mode.
    stream: bool
    # JSON Schema handed to invoke_llm_with_structured_output.
    json_schema: Mapping[str, Any]
    # Canned LLM response the mocked model instance returns.
    expected_llm_response: LLMResult
    # Expected type of the returned result; None when the case raises instead.
    expected_result_type: type[LLMResultWithStructuredOutput] | None
    # True when the case is expected to raise rather than return a result.
    should_raise: bool
    # Exception type to expect when should_raise is True; defaults to OutputParserError.
    expected_error: NotRequired[type[OutputParserError]]
    # Optional parameter rules to install on the model schema before invoking.
    parameter_rules: NotRequired[list[ParameterRule]]
# Loose alias for JSON-schema-like / JSON-object dictionaries used throughout these tests.
SchemaData = dict[str, Any]
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
@ -71,7 +93,7 @@ def get_model_instance() -> MagicMock:
def test_structured_output_parser():
"""Test cases for invoke_llm_with_structured_output function"""
testcases = [
testcases: list[StructuredOutputTestCase] = [
# Test case 1: Model with native structured output support, non-streaming
{
"name": "native_structured_output_non_streaming",
@ -88,39 +110,6 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 2: Model with native structured output support, streaming
{
"name": "native_structured_output_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 3: Model without native structured output support, non-streaming
{
"name": "prompt_based_structured_output_non_streaming",
@ -137,78 +126,24 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 4: Model without native structured output support, streaming
{
"name": "prompt_based_structured_output_streaming",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"answer": "test'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
),
),
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' response"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 5: Streaming with list content
{
"name": "streaming_with_list_content",
"name": "non_streaming_with_list_content",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
TextPromptMessageContent(data=' "value"}'),
]
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data=' "value"}'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 6: Error case - non-string LLM response content (non-streaming)
@ -253,7 +188,13 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
"parameter_rules": [
MagicMock(name="response_format", options=["json_schema"], required=False),
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["json_schema"],
),
],
"expected_llm_response": LLMResult(
model="gpt-4o",
@ -272,7 +213,13 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
"parameter_rules": [
MagicMock(name="response_format", options=["JSON"], required=False),
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["JSON"],
),
],
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
@ -285,89 +232,72 @@ def test_structured_output_parser():
]
for case in testcases:
# Setup model entity
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
provider = case["provider"]
model_name = case["model_name"]
support_structure_output = case["support_structure_output"]
json_schema = case["json_schema"]
stream = case["stream"]
# Add parameter rules if specified
if "parameter_rules" in case:
model_schema.parameter_rules = case["parameter_rules"]
model_schema = get_model_entity(provider, model_name, support_structure_output)
parameter_rules = case.get("parameter_rules")
if parameter_rules is not None:
model_schema.parameter_rules = parameter_rules
# Setup model instance
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = case["expected_llm_response"]
# Setup prompt messages
prompt_messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Generate a response according to the schema."),
]
if case["should_raise"]:
# Test error cases
with pytest.raises(case["expected_error"]): # noqa: PT012
if case["stream"]:
expected_error = case.get("expected_error", OutputParserError)
with pytest.raises(expected_error): # noqa: PT012
if stream:
result_generator = invoke_llm_with_structured_output(
provider=case["provider"],
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
json_schema=json_schema,
)
# Consume the generator to trigger the error
list(result_generator)
else:
invoke_llm_with_structured_output(
provider=case["provider"],
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
json_schema=json_schema,
)
else:
# Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}
result = invoke_llm_with_structured_output(
provider=case["provider"],
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=case["json_schema"],
json_schema=json_schema,
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)
if case["expected_result_type"] == "generator":
# Test streaming results
assert hasattr(result, "__iter__")
chunks = list(result)
assert len(chunks) > 0
expected_result_type = case["expected_result_type"]
assert expected_result_type is not None
assert isinstance(result, expected_result_type)
assert result.model == model_name
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
# Verify all chunks are LLMResultChunkWithStructuredOutput
for chunk in chunks[:-1]: # All except last
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
assert chunk.model == case["model_name"]
# Last chunk should have structured output
last_chunk = chunks[-1]
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
assert last_chunk.structured_output is not None
assert isinstance(last_chunk.structured_output, dict)
else:
# Test non-streaming results
assert isinstance(result, case["expected_result_type"])
assert result.model == case["model_name"]
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
# Verify model_instance.invoke_llm was called with correct parameters
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args
assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["stream"] == stream
assert call_args.kwargs["user"] == "test_user"
assert "temperature" in call_args.kwargs["model_parameters"]
assert "max_tokens" in call_args.kwargs["model_parameters"]
@ -376,45 +306,32 @@ def test_structured_output_parser():
def test_parse_structured_output_edge_cases():
"""Test edge cases for structured output parsing"""
# Test case with list that contains dict (reasoning model scenario)
testcase_list_with_dict = {
"name": "list_with_dict_parsing",
"provider": "deepseek",
"model_name": "deepseek-r1",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
}
# Setup for list parsing test
model_schema = get_model_entity(
testcase_list_with_dict["provider"],
testcase_list_with_dict["model_name"],
testcase_list_with_dict["support_structure_output"],
provider = "deepseek"
model_name = "deepseek-r1"
support_structure_output = False
json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
expected_llm_response = LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
)
model_schema = get_model_entity(provider, model_name, support_structure_output)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
model_instance.invoke_llm.return_value = expected_llm_response
prompt_messages = [UserPromptMessage(content="Test reasoning")]
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Mock json_repair to return a list with dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
result = invoke_llm_with_structured_output(
provider=testcase_list_with_dict["provider"],
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=testcase_list_with_dict["json_schema"],
json_schema=json_schema,
)
assert isinstance(result, LLMResultWithStructuredOutput)
@ -424,18 +341,16 @@ def test_parse_structured_output_edge_cases():
def test_model_specific_schema_preparation():
"""Test schema preparation for different model types"""
# Test Gemini model
gemini_case = {
"provider": "google",
"model_name": "gemini-pro",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
provider = "google"
model_name = "gemini-pro"
support_structure_output = True
json_schema: SchemaData = {
"type": "object",
"properties": {"result": {"type": "boolean"}},
"additionalProperties": False,
}
model_schema = get_model_entity(
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
)
model_schema = get_model_entity(provider, model_name, support_structure_output)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
@ -447,11 +362,11 @@ def test_model_specific_schema_preparation():
prompt_messages = [UserPromptMessage(content="Test")]
result = invoke_llm_with_structured_output(
provider=gemini_case["provider"],
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=gemini_case["json_schema"],
json_schema=json_schema,
)
assert isinstance(result, LLMResultWithStructuredOutput)
@ -493,40 +408,26 @@ def test_structured_output_with_pydantic_model_non_streaming():
assert result.name == "test"
def test_structured_output_with_pydantic_model_streaming():
def test_structured_output_with_pydantic_model_list_content():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
def mock_streaming_response():
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
),
)
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
),
)
model_instance.invoke_llm.return_value = mock_streaming_response()
model_instance.invoke_llm.return_value = LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"name":'),
TextPromptMessageContent(data=' "test"}'),
]
),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
)
result = invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
output_model=ExampleOutput
output_model=ExampleOutput,
)
assert isinstance(result, ExampleOutput)
@ -548,51 +449,51 @@ def test_structured_output_with_pydantic_model_validation_error():
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
output_model=ExampleOutput
output_model=ExampleOutput,
)
class TestGetDefaultValueForType:
"""Test cases for _get_default_value_for_type function"""
"""Test cases for get_default_value_for_type function"""
def test_string_type(self):
assert _get_default_value_for_type("string") == ""
assert get_default_value_for_type("string") == ""
def test_object_type(self):
assert _get_default_value_for_type("object") == {}
assert get_default_value_for_type("object") == {}
def test_array_type(self):
assert _get_default_value_for_type("array") == []
assert get_default_value_for_type("array") == []
def test_number_type(self):
assert _get_default_value_for_type("number") == 0
assert get_default_value_for_type("number") == 0
def test_integer_type(self):
assert _get_default_value_for_type("integer") == 0
assert get_default_value_for_type("integer") == 0
def test_boolean_type(self):
assert _get_default_value_for_type("boolean") is False
assert get_default_value_for_type("boolean") is False
def test_null_type(self):
assert _get_default_value_for_type("null") is None
assert get_default_value_for_type("null") is None
def test_none_type(self):
assert _get_default_value_for_type(None) is None
assert get_default_value_for_type(None) is None
def test_unknown_type(self):
assert _get_default_value_for_type("unknown") is None
assert get_default_value_for_type("unknown") is None
def test_union_type_string_null(self):
# ["string", "null"] should return "" (first non-null type)
assert _get_default_value_for_type(["string", "null"]) == ""
assert get_default_value_for_type(["string", "null"]) == ""
def test_union_type_null_first(self):
# ["null", "integer"] should return 0 (first non-null type)
assert _get_default_value_for_type(["null", "integer"]) == 0
assert get_default_value_for_type(["null", "integer"]) == 0
def test_union_type_only_null(self):
# ["null"] should return None
assert _get_default_value_for_type(["null"]) is None
assert get_default_value_for_type(["null"]) is None
class TestFillDefaultsFromSchema:
@ -600,7 +501,7 @@ class TestFillDefaultsFromSchema:
def test_simple_required_fields(self):
"""Test filling simple required fields"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"name": {"type": "string"},
@ -609,7 +510,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "age"],
}
output = {"name": "Alice"}
output: SchemaData = {"name": "Alice"}
result = fill_defaults_from_schema(output, schema)
@ -619,7 +520,7 @@ class TestFillDefaultsFromSchema:
def test_non_required_fields_not_filled(self):
"""Test that non-required fields are not filled"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"required_field": {"type": "string"},
@ -627,7 +528,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["required_field"],
}
output = {}
output: SchemaData = {}
result = fill_defaults_from_schema(output, schema)
@ -636,7 +537,7 @@ class TestFillDefaultsFromSchema:
def test_nested_object_required_fields(self):
"""Test filling nested object required fields"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"user": {
@ -659,7 +560,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user"],
}
output = {
output: SchemaData = {
"user": {
"name": "Alice",
"address": {
@ -684,7 +585,7 @@ class TestFillDefaultsFromSchema:
def test_missing_nested_object_created(self):
"""Test that missing required nested objects are created"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"metadata": {
@ -698,7 +599,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["metadata"],
}
output = {}
output: SchemaData = {}
result = fill_defaults_from_schema(output, schema)
@ -710,7 +611,7 @@ class TestFillDefaultsFromSchema:
def test_all_types_default_values(self):
"""Test default values for all types"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"str_field": {"type": "string"},
@ -722,7 +623,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
}
output = {}
output: SchemaData = {}
result = fill_defaults_from_schema(output, schema)
@ -737,7 +638,7 @@ class TestFillDefaultsFromSchema:
def test_existing_values_preserved(self):
"""Test that existing values are not overwritten"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"name": {"type": "string"},
@ -745,7 +646,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "count"],
}
output = {"name": "Bob", "count": 42}
output: SchemaData = {"name": "Bob", "count": 42}
result = fill_defaults_from_schema(output, schema)
@ -753,7 +654,7 @@ class TestFillDefaultsFromSchema:
def test_complex_nested_structure(self):
"""Test complex nested structure with multiple levels"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"user": {
@ -789,7 +690,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user", "tags", "metadata", "is_active"],
}
output = {
output: SchemaData = {
"user": {
"name": "Alice",
"age": 25,
@ -829,8 +730,8 @@ class TestFillDefaultsFromSchema:
def test_empty_schema(self):
"""Test with empty schema"""
schema = {}
output = {"any": "value"}
schema: SchemaData = {}
output: SchemaData = {"any": "value"}
result = fill_defaults_from_schema(output, schema)
@ -838,14 +739,14 @@ class TestFillDefaultsFromSchema:
def test_schema_without_required(self):
"""Test schema without required field"""
schema = {
schema: SchemaData = {
"type": "object",
"properties": {
"optional1": {"type": "string"},
"optional2": {"type": "integer"},
},
}
output = {}
output: SchemaData = {}
result = fill_defaults_from_schema(output, schema)