From b66db183c922e6470a35a07b8500731647b6bce5 Mon Sep 17 00:00:00 2001 From: Stream Date: Sun, 1 Feb 2026 02:47:28 +0800 Subject: [PATCH] vibe: implement file structured output --- .../llm_generator/output_parser/file_ref.py | 308 +++++++------- .../output_parser/structured_output.py | 134 +++--- api/core/sandbox/bash/session.py | 136 +++--- api/core/workflow/nodes/llm/llm_utils.py | 4 +- api/core/workflow/nodes/llm/node.py | 78 ++-- api/tests/fixtures/file output schema.yml | 5 +- .../output_parser/test_file_ref.py | 290 +++---------- .../test_structured_output_parser.py | 393 +++++++----------- 8 files changed, 554 insertions(+), 794 deletions(-) diff --git a/api/core/llm_generator/output_parser/file_ref.py b/api/core/llm_generator/output_parser/file_ref.py index 83489e6a79..74e47570c6 100644 --- a/api/core/llm_generator/output_parser/file_ref.py +++ b/api/core/llm_generator/output_parser/file_ref.py @@ -1,188 +1,190 @@ -""" -File reference detection and conversion for structured output. - -This module provides utilities to: -1. Detect file reference fields in JSON Schema (format: "dify-file-ref") -2. Convert file ID strings to File objects after LLM returns -""" - -import uuid -from collections.abc import Mapping -from typing import Any +from collections.abc import Callable, Mapping, Sequence +from typing import Any, cast from core.file import File from core.variables.segments import ArrayFileSegment, FileSegment -from factories.file_factory import build_from_mapping -FILE_REF_FORMAT = "dify-file-ref" +FILE_PATH_SCHEMA_TYPE = "file" +FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"} +FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)." -def is_file_ref_property(schema: dict) -> bool: - """Check if a schema property is a file reference.""" - return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT +def is_file_path_property(schema: Mapping[str, Any]) -> bool: + if schema.get("type") == FILE_PATH_SCHEMA_TYPE: + return True + format_value = schema.get("format") + if not isinstance(format_value, str): + return False + normalized_format = format_value.lower().replace("_", "-") + return normalized_format in FILE_PATH_SCHEMA_FORMATS -def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]: - """ - Recursively detect file reference fields in schema. - - Args: - schema: JSON Schema to analyze - path: Current path in the schema (used for recursion) - - Returns: - List of JSON paths containing file refs, e.g., ["image_id", "files[*]"] - """ - file_ref_paths: list[str] = [] +def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]: + file_path_fields: list[str] = [] schema_type = schema.get("type") if schema_type == "object": - for prop_name, prop_schema in schema.get("properties", {}).items(): - current_path = f"{path}.{prop_name}" if path else prop_name + properties = schema.get("properties") + if isinstance(properties, Mapping): + properties_mapping = cast(Mapping[str, Any], properties) + for prop_name, prop_schema in properties_mapping.items(): + if not isinstance(prop_schema, Mapping): + continue + prop_schema_mapping = cast(Mapping[str, Any], prop_schema) + current_path = f"{path}.{prop_name}" if path else prop_name - if is_file_ref_property(prop_schema): - file_ref_paths.append(current_path) - elif isinstance(prop_schema, dict): - file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path)) + if is_file_path_property(prop_schema_mapping): + file_path_fields.append(current_path) + else: + file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path)) elif schema_type == "array": - items_schema = schema.get("items", {}) + items_schema = schema.get("items") + if not isinstance(items_schema, Mapping): + return file_path_fields + items_schema_mapping = cast(Mapping[str, Any], items_schema) array_path = f"{path}[*]" if path else "[*]" - if is_file_ref_property(items_schema): - file_ref_paths.append(array_path) - elif isinstance(items_schema, dict): - file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path)) - - return file_ref_paths - - -def convert_file_refs_in_output( - output: Mapping[str, Any], - json_schema: Mapping[str, Any], - tenant_id: str, -) -> dict[str, Any]: - """ - Convert file ID strings to File objects based on schema. - - Args: - output: The structured_output from LLM result - json_schema: The original JSON schema (to detect file ref fields) - tenant_id: Tenant ID for file lookup - - Returns: - Output with file references converted to File objects - """ - file_ref_paths = detect_file_ref_fields(json_schema) - if not file_ref_paths: - return dict(output) - - result = _deep_copy_dict(output) - - for path in file_ref_paths: - _convert_path_in_place(result, path.split("."), tenant_id) - - return result - - -def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]: - """Deep copy a mapping to a mutable dict.""" - result: dict[str, Any] = {} - for key, value in obj.items(): - if isinstance(value, Mapping): - result[key] = _deep_copy_dict(value) - elif isinstance(value, list): - result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value] + if is_file_path_property(items_schema_mapping): + file_path_fields.append(array_path) else: - result[key] = value - return result + file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path)) + + return file_path_fields -def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None: - """Convert file refs at the given path in place, wrapping in Segment types.""" +def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]: + result = _deep_copy_value(schema) + if not isinstance(result, dict): + raise ValueError("structured_output_schema must be a JSON object") + result_dict = cast(dict[str, Any], result) + + file_path_fields: list[str] = [] + _adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields) + return result_dict, file_path_fields + + +def convert_sandbox_file_paths_in_output( + output: Mapping[str, Any], + file_path_fields: Sequence[str], + file_resolver: Callable[[str], File], +) -> tuple[dict[str, Any], list[File]]: + if not file_path_fields: + return dict(output), [] + + result = _deep_copy_value(output) + if not isinstance(result, dict): + raise ValueError("Structured output must be a JSON object") + result_dict = cast(dict[str, Any], result) + + files: list[File] = [] + for path in file_path_fields: + _convert_path_in_place(result_dict, path.split("."), file_resolver, files) + + return result_dict, files + + +def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None: + schema_type = schema.get("type") + + if schema_type == "object": + properties = schema.get("properties") + if isinstance(properties, Mapping): + properties_mapping = cast(Mapping[str, Any], properties) + for prop_name, prop_schema in properties_mapping.items(): + if not isinstance(prop_schema, dict): + continue + prop_schema_dict = cast(dict[str, Any], prop_schema) + current_path = f"{path}.{prop_name}" if path else prop_name + + if is_file_path_property(prop_schema_dict): + _normalize_file_path_schema(prop_schema_dict) + file_path_fields.append(current_path) + else: + _adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields) + + elif schema_type == "array": + items_schema = schema.get("items") + if not isinstance(items_schema, dict): + return + items_schema_dict = cast(dict[str, Any], items_schema) + array_path = f"{path}[*]" if path else "[*]" + + if is_file_path_property(items_schema_dict): + _normalize_file_path_schema(items_schema_dict) + file_path_fields.append(array_path) + else: + _adapt_schema_in_place(items_schema_dict, array_path, file_path_fields) + + +def _normalize_file_path_schema(schema: dict[str, Any]) -> None: + schema["type"] = "string" + schema.pop("format", None) + description = schema.get("description", "") + if description: + schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}" + else: + schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX + + +def _deep_copy_value(value: Any) -> Any: + if isinstance(value, Mapping): + mapping = cast(Mapping[str, Any], value) + return {key: _deep_copy_value(item) for key, item in mapping.items()} + if isinstance(value, list): + list_value = cast(list[Any], value) + return [_deep_copy_value(item) for item in list_value] + return value + + +def _convert_path_in_place( + obj: dict[str, Any], + path_parts: list[str], + file_resolver: Callable[[str], File], + files: list[File], +) -> None: if not path_parts: return current = path_parts[0] remaining = path_parts[1:] - # Handle array notation like "files[*]" if current.endswith("[*]"): - key = current[:-3] if current != "[*]" else None - target = obj.get(key) if key else obj + key = current[:-3] if current != "[*]" else "" + target_value = obj.get(key) if key else obj - if isinstance(target, list): + if isinstance(target_value, list): + target_list = cast(list[Any], target_value) if remaining: - # Nested array with remaining path - recurse into each item - for item in target: + for item in target_list: if isinstance(item, dict): - _convert_path_in_place(item, remaining, tenant_id) + item_dict = cast(dict[str, Any], item) + _convert_path_in_place(item_dict, remaining, file_resolver, files) else: - # Array of file IDs - convert all and wrap in ArrayFileSegment - files: list[File] = [] - for item in target: - file = _convert_file_id(item, tenant_id) - if file is not None: - files.append(file) - # Replace the array with ArrayFileSegment + resolved_files: list[File] = [] + for item in target_list: + if not isinstance(item, str): + raise ValueError("File path must be a string") + file = file_resolver(item) + files.append(file) + resolved_files.append(file) if key: - obj[key] = ArrayFileSegment(value=files) + obj[key] = ArrayFileSegment(value=resolved_files) return if not remaining: - # Leaf node - convert the value and wrap in FileSegment - if current in obj: - file = _convert_file_id(obj[current], tenant_id) - if file is not None: - obj[current] = FileSegment(value=file) - else: - obj[current] = None - else: - # Recurse into nested object - if current in obj and isinstance(obj[current], dict): - _convert_path_in_place(obj[current], remaining, tenant_id) + if current not in obj: + return + value = obj[current] + if value is None: + obj[current] = None + return + if not isinstance(value, str): + raise ValueError("File path must be a string") + file = file_resolver(value) + files.append(file) + obj[current] = FileSegment(value=file) + return - -def _convert_file_id(file_id: Any, tenant_id: str) -> File | None: - """ - Convert a file ID string to a File object. - - Tries multiple file sources in order: - 1. ToolFile (files generated by tools/workflows) - 2. UploadFile (files uploaded by users) - """ - if not isinstance(file_id, str): - return None - - # Validate UUID format - try: - uuid.UUID(file_id) - except ValueError: - return None - - # Try ToolFile first (files generated by tools/workflows) - try: - return build_from_mapping( - mapping={ - "transfer_method": "tool_file", - "tool_file_id": file_id, - }, - tenant_id=tenant_id, - ) - except ValueError: - pass - - # Try UploadFile (files uploaded by users) - try: - return build_from_mapping( - mapping={ - "transfer_method": "local_file", - "upload_file_id": file_id, - }, - tenant_id=tenant_id, - ) - except ValueError: - pass - - # File not found in any source - return None + if current in obj and isinstance(obj[current], dict): + _convert_path_in_place(obj[current], remaining, file_resolver, files) diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index a069f0409c..c790feda53 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -8,7 +8,7 @@ import json_repair from pydantic import BaseModel, TypeAdapter, ValidationError from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output +from core.llm_generator.output_parser.file_ref import detect_file_path_fields from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT from core.model_manager import ModelInstance from core.model_runtime.callbacks.base_callback import Callback @@ -55,12 +55,11 @@ def invoke_llm_with_structured_output( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], json_schema: Mapping[str, Any], - model_parameters: Mapping | None = None, + model_parameters: Mapping[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, user: str | None = None, callbacks: list[Callback] | None = None, - tenant_id: str | None = None, ) -> LLMResultWithStructuredOutput: """ Invoke large language model with structured output. @@ -78,14 +77,12 @@ def invoke_llm_with_structured_output( :param stop: stop words :param user: unique user id :param callbacks: callbacks - :param tenant_id: tenant ID for file reference conversion. When provided and - json_schema contains file reference fields (format: "dify-file-ref"), - file IDs in the output will be automatically converted to File objects. - :return: full response or stream response chunk generator result + :return: response with structured output """ - model_parameters_with_json_schema: dict[str, Any] = { - **(model_parameters or {}), - } + model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {}) + + if detect_file_path_fields(json_schema): + raise OutputParserError("Structured output file paths are only supported in sandbox mode.") # Determine structured output strategy @@ -122,14 +119,6 @@ def invoke_llm_with_structured_output( # Fill missing fields with default values structured_output = fill_defaults_from_schema(structured_output, json_schema) - # Convert file references if tenant_id is provided - if tenant_id is not None: - structured_output = convert_file_refs_in_output( - output=structured_output, - json_schema=json_schema, - tenant_id=tenant_id, - ) - return LLMResultWithStructuredOutput( structured_output=structured_output, model=llm_result.model, @@ -147,12 +136,11 @@ def invoke_llm_with_pydantic_model( model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], output_model: type[T], - model_parameters: Mapping | None = None, + model_parameters: Mapping[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, user: str | None = None, callbacks: list[Callback] | None = None, - tenant_id: str | None = None, ) -> T: """ Invoke large language model with a Pydantic output model. @@ -160,11 +148,8 @@ def invoke_llm_with_pydantic_model( This helper generates a JSON schema from the Pydantic model, invokes the structured-output LLM path, and validates the result. - The stream parameter controls the underlying LLM invocation mode: - - stream=True (default): Uses streaming LLM call, consumes the generator internally - - stream=False: Uses non-streaming LLM call - - In both cases, the function returns the validated Pydantic model directly. + The helper performs a non-streaming invocation and returns the validated + Pydantic model directly. """ json_schema = _schema_from_pydantic(output_model) @@ -179,7 +164,6 @@ def invoke_llm_with_pydantic_model( stop=stop, user=user, callbacks=callbacks, - tenant_id=tenant_id, ) structured_output = result.structured_output @@ -236,25 +220,27 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]: return _parse_structured_output(content) -def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]: +def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]: """Parse JSON from tool call arguments.""" if not arguments: raise OutputParserError("Tool call arguments is empty") try: - parsed = json.loads(arguments) - if not isinstance(parsed, dict): + parsed_any = json.loads(arguments) + if not isinstance(parsed_any, dict): raise OutputParserError(f"Tool call arguments is not a dict: {arguments}") + parsed = cast(dict[str, Any], parsed_any) return parsed except json.JSONDecodeError: # Try to repair malformed JSON - repaired = json_repair.loads(arguments) - if not isinstance(repaired, dict): + repaired_any = json_repair.loads(arguments) + if not isinstance(repaired_any, dict): raise OutputParserError(f"Failed to parse tool call arguments: {arguments}") + repaired: dict[str, Any] = repaired_any return repaired -def _get_default_value_for_type(type_name: str | list[str] | None) -> Any: +def get_default_value_for_type(type_name: str | list[str] | None) -> Any: """Get default empty value for a JSON schema type.""" # Handle array of types (e.g., ["string", "null"]) if isinstance(type_name, list): @@ -311,7 +297,7 @@ def fill_defaults_from_schema( # Create empty object and recursively fill its required fields result[prop_name] = fill_defaults_from_schema({}, prop_schema) else: - result[prop_name] = _get_default_value_for_type(prop_type) + result[prop_name] = get_default_value_for_type(prop_type) elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema: # Field exists and is an object, recursively fill nested required fields result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema) @@ -322,10 +308,10 @@ def fill_defaults_from_schema( def _handle_native_json_schema( provider: str, model_schema: AIModelEntity, - structured_output_schema: Mapping, - model_parameters: dict, + structured_output_schema: Mapping[str, Any], + model_parameters: dict[str, Any], rules: list[ParameterRule], -): +) -> dict[str, Any]: """ Handle structured output for models with native JSON schema support. @@ -347,7 +333,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict, rules: list): +def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: """ Set the appropriate response format parameter based on model rules. @@ -363,7 +349,7 @@ def _set_response_format(model_parameters: dict, rules: list): def _handle_prompt_based_schema( - prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping + prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any] ) -> list[PromptMessage]: """ Handle structured output for models without native JSON schema support. @@ -400,28 +386,27 @@ def _handle_prompt_based_schema( return updated_prompt -def _parse_structured_output(result_text: str) -> Mapping[str, Any]: - structured_output: Mapping[str, Any] = {} - parsed: Mapping[str, Any] = {} +def _parse_structured_output(result_text: str) -> dict[str, Any]: try: - parsed = TypeAdapter(Mapping).validate_json(result_text) - if not isinstance(parsed, dict): - raise OutputParserError(f"Failed to parse structured output: {result_text}") - structured_output = parsed + parsed = TypeAdapter(dict[str, Any]).validate_json(result_text) + return parsed except ValidationError: # if the result_text is not a valid json, try to repair it - temp_parsed = json_repair.loads(result_text) + temp_parsed: Any = json_repair.loads(result_text) + if isinstance(temp_parsed, list): + temp_parsed_list = cast(list[Any], temp_parsed) + dict_items: list[dict[str, Any]] = [] + for item in temp_parsed_list: + if isinstance(item, dict): + dict_items.append(cast(dict[str, Any], item)) + temp_parsed = dict_items[0] if dict_items else {} if not isinstance(temp_parsed, dict): - # handle reasoning model like deepseek-r1 got '\n\n\n' prefix - if isinstance(temp_parsed, list): - temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {}) - else: - raise OutputParserError(f"Failed to parse structured output: {result_text}") - structured_output = cast(dict, temp_parsed) - return structured_output + raise OutputParserError(f"Failed to parse structured output: {result_text}") + temp_parsed_dict = cast(dict[str, Any], temp_parsed) + return temp_parsed_dict -def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping): +def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]: """ Prepare JSON schema based on model requirements. @@ -433,54 +418,49 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema """ # Deep copy to avoid modifying the original schema - processed_schema = dict(deepcopy(schema)) + processed_schema = deepcopy(schema) + processed_schema_dict = dict(processed_schema) # Convert boolean types to string types (common requirement) - convert_boolean_to_string(processed_schema) + convert_boolean_to_string(processed_schema_dict) # Apply model-specific transformations if SpecialModelType.GEMINI in model_schema.model: - remove_additional_properties(processed_schema) - return processed_schema - elif SpecialModelType.OLLAMA in provider: - return processed_schema - else: - # Default format with name field - return {"schema": processed_schema, "name": "llm_response"} + remove_additional_properties(processed_schema_dict) + return processed_schema_dict + if SpecialModelType.OLLAMA in provider: + return processed_schema_dict + + # Default format with name field + return {"schema": processed_schema_dict, "name": "llm_response"} -def remove_additional_properties(schema: dict): +def remove_additional_properties(schema: dict[str, Any]) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. :param schema: JSON schema to modify in-place """ - if not isinstance(schema, dict): - return - # Remove additionalProperties at current level schema.pop("additionalProperties", None) # Process nested structures recursively for value in schema.values(): if isinstance(value, dict): - remove_additional_properties(value) + remove_additional_properties(cast(dict[str, Any], value)) elif isinstance(value, list): - for item in value: + for item in cast(list[Any], value): if isinstance(item, dict): - remove_additional_properties(item) + remove_additional_properties(cast(dict[str, Any], item)) -def convert_boolean_to_string(schema: dict): +def convert_boolean_to_string(schema: dict[str, Any]) -> None: """ Convert boolean type specifications to string in JSON schema. :param schema: JSON schema to modify in-place """ - if not isinstance(schema, dict): - return - # Check for boolean type at current level if schema.get("type") == "boolean": schema["type"] = "string" @@ -488,8 +468,8 @@ def convert_boolean_to_string(schema: dict): # Process nested dictionaries and lists recursively for value in schema.values(): if isinstance(value, dict): - convert_boolean_to_string(value) + convert_boolean_to_string(cast(dict[str, Any], value)) elif isinstance(value, list): - for item in value: + for item in cast(list[Any], value): if isinstance(item, dict): - convert_boolean_to_string(item) + convert_boolean_to_string(cast(dict[str, Any], item)) diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index 57e790b2c2..60691790a2 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -14,7 +14,8 @@ from core.skill.entities import ToolAccessPolicy from core.skill.entities.tool_dependencies import ToolDependencies from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager -from core.virtual_environment.__base.helpers import pipeline +from core.virtual_environment.__base.exec import CommandExecutionError +from core.virtual_environment.__base.helpers import execute, pipeline from ..bash.dify_cli import DifyCliConfig from ..entities import DifyCli @@ -119,21 +120,6 @@ class SandboxBashSession: return self._bash_tool def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]: - """ - Collect files from sandbox output directory and save them as ToolFiles. - - Scans the specified output directory in sandbox, downloads each file, - saves it as a ToolFile, and returns a list of File objects. The File - objects will have valid tool_file_id that can be referenced by subsequent - nodes via structured output. - - Args: - output_dir: Directory path in sandbox to scan for output files. - Defaults to "output" (relative to workspace). - - Returns: - List of File objects representing the collected files. - """ vm = self._sandbox.vm collected_files: list[File] = [] @@ -144,8 +130,6 @@ class SandboxBashSession: logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc) return collected_files - tool_file_manager = ToolFileManager() - for file_state in file_states: # Skip files that are too large if file_state.size > MAX_OUTPUT_FILE_SIZE: @@ -162,47 +146,14 @@ class SandboxBashSession: file_content = vm.download_file(file_state.path) file_binary = file_content.getvalue() - # Determine mime type from extension filename = os.path.basename(file_state.path) - mime_type, _ = mimetypes.guess_type(filename) - if not mime_type: - mime_type = "application/octet-stream" - - # Save as ToolFile - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mime_type, - filename=filename, - ) - - # Determine file type from mime type - file_type = _get_file_type_from_mime(mime_type) - extension = os.path.splitext(filename)[1] if "." in filename else ".bin" - url = sign_tool_file(tool_file.id, extension) - - # Create File object with tool_file_id as related_id - file_obj = File( - id=tool_file.id, # Use tool_file_id as the File id for easy reference - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=filename, - extension=extension, - mime_type=mime_type, - size=len(file_binary), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, - ) + file_obj = self._create_tool_file(filename=filename, file_binary=file_binary) collected_files.append(file_obj) logger.info( "Collected sandbox output file: %s -> tool_file_id=%s", file_state.path, - tool_file.id, + file_obj.id, ) except Exception as exc: @@ -216,6 +167,85 @@ class SandboxBashSession: ) return collected_files + def download_file(self, path: str) -> File: + path_kind = self._detect_path_kind(path) + if path_kind == "dir": + raise ValueError("Directory outputs are not supported") + if path_kind != "file": + raise ValueError(f"Sandbox file not found: {path}") + + file_content = self._sandbox.vm.download_file(path) + file_binary = file_content.getvalue() + if len(file_binary) > MAX_OUTPUT_FILE_SIZE: + raise ValueError(f"Sandbox file exceeds size limit: {path}") + + filename = os.path.basename(path) or "file" + return self._create_tool_file(filename=filename, file_binary=file_binary) + + def _detect_path_kind(self, path: str) -> str: + script = r""" +import os +import sys + +p = sys.argv[1] +if os.path.isdir(p): + print("dir") + raise SystemExit(0) +if os.path.isfile(p): + print("file") + raise SystemExit(0) +print("none") +raise SystemExit(2) +""" + try: + result = execute( + self._sandbox.vm, + [ + "sh", + "-c", + 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"', + script, + path, + ], + timeout=10, + error_message="Failed to inspect sandbox path", + ) + except CommandExecutionError as exc: + raise ValueError(str(exc)) from exc + return result.stdout.decode("utf-8", errors="replace").strip() + + def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File: + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + tool_file = ToolFileManager().create_file_by_raw( + user_id=self._user_id, + tenant_id=self._tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mime_type, + filename=filename, + ) + + file_type = _get_file_type_from_mime(mime_type) + extension = os.path.splitext(filename)[1] if "." in filename else ".bin" + url = sign_tool_file(tool_file.id, extension) + + return File( + id=tool_file.id, + tenant_id=self._tenant_id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + filename=filename, + extension=extension, + mime_type=mime_type, + size=len(file_binary), + related_id=tool_file.id, + url=url, + storage_key=tool_file.file_key, + ) + def _get_file_type_from_mime(mime_type: str) -> FileType: """Determine FileType from mime type.""" diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 86bf8f473e..b1460ae402 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str: """ Build a text description of generated files for inclusion in context. - The description includes file_id which can be used by subsequent nodes - to reference the files via structured output. + The description includes file_id for context; structured output file paths + are only supported in sandbox mode. """ if not files: return "" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c4e5ef323b..4a8f96f09e 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -21,7 +21,11 @@ from core.app_assets.constants import AppAssetsAttrs from core.file import FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output +from core.llm_generator.output_parser.file_ref import ( + adapt_schema_for_sandbox_file_paths, + convert_sandbox_file_paths_in_output, + detect_file_path_fields, +) from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.memory.base import BaseMemory from core.model_manager import ModelInstance, ModelManager @@ -297,13 +301,20 @@ class LLMNode(Node[LLMNodeData]): ) structured_output_schema: Mapping[str, Any] | None + structured_output_file_paths: list[str] = [] if self.node_data.structured_output_enabled: if not self.node_data.structured_output: raise ValueError("structured_output_enabled is True but structured_output is not set") - structured_output_schema = LLMNode.fetch_structured_output_schema( - structured_output=self.node_data.structured_output - ) + raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output) + if self.node_data.computer_use: + structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths( + raw_schema + ) + else: + if detect_file_path_fields(raw_schema): + raise LLMNodeError("Structured output file paths are only supported in sandbox mode.") + structured_output_schema = raw_schema else: structured_output_schema = None @@ -319,8 +330,10 @@ class LLMNode(Node[LLMNodeData]): stop=stop, variable_pool=variable_pool, tool_dependencies=tool_dependencies, - structured_output_schema=structured_output_schema + structured_output_schema=structured_output_schema, + structured_output_file_paths=structured_output_file_paths, ) + elif self.tool_call_enabled: generator = self._invoke_llm_with_tools( model_instance=model_instance, @@ -330,7 +343,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, node_inputs=node_inputs, process_data=process_data, - structured_output_schema=structured_output_schema + structured_output_schema=structured_output_schema, ) else: # Use traditional LLM invocation @@ -532,8 +545,8 @@ class LLMNode(Node[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), user=user_id, - tenant_id=tenant_id, ) + else: request_start_time = time.perf_counter() @@ -1880,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, node_inputs: dict[str, Any], process_data: dict[str, Any], - structured_output_schema: Mapping[str, Any] | None + structured_output_schema: Mapping[str, Any] | None, ) -> Generator[NodeEventBase, None, LLMGenerationData]: """Invoke LLM with tools support (from Agent V2). @@ -1906,14 +1919,14 @@ class LLMNode(Node[LLMNodeData]): files=prompt_files, max_iterations=self._node_data.max_iterations or 10, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), - structured_output_schema=structured_output_schema + structured_output_schema=structured_output_schema, ) # Run strategy outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, - stop=list(stop or []) + stop=list(stop or []), ) result = yield from self._process_tool_outputs(outputs) @@ -1927,10 +1940,12 @@ class LLMNode(Node[LLMNodeData]): stop: Sequence[str] | None, variable_pool: VariablePool, tool_dependencies: ToolDependencies | None, - structured_output_schema: Mapping[str, Any] | None + structured_output_schema: Mapping[str, Any] | None, + structured_output_file_paths: Sequence[str] | None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]: result: LLMGenerationData | None = None sandbox_output_files: list[File] = [] + structured_output_files: list[File] = [] # FIXME(Mairuis): Async processing for bash session. with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session: @@ -1948,17 +1963,31 @@ class LLMNode(Node[LLMNodeData]): max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), - structured_output_schema=structured_output_schema + structured_output_schema=structured_output_schema, ) outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, - stop=list(stop or []) + stop=list(stop or []), ) result = yield from self._process_tool_outputs(outputs) + if result and result.structured_output and structured_output_file_paths: + structured_output_payload = result.structured_output.structured_output or {} + try: + converted_output, structured_output_files = convert_sandbox_file_paths_in_output( + output=structured_output_payload, + file_path_fields=structured_output_file_paths, + file_resolver=session.download_file, + ) + except ValueError as exc: + raise LLMNodeError(str(exc)) from exc + result = result.model_copy( + update={"structured_output": LLMStructuredOutput(structured_output=converted_output)} + ) + # Collect output files from sandbox before session ends # Files are saved as ToolFiles with valid tool_file_id for later reference sandbox_output_files = session.collect_output_files() @@ -1971,7 +2000,7 @@ class LLMNode(Node[LLMNodeData]): yield structured_output # Merge sandbox output files into result - if sandbox_output_files: + if sandbox_output_files or structured_output_files: result = LLMGenerationData( text=result.text, reasoning_contents=result.reasoning_contents, @@ -1979,7 +2008,7 @@ class LLMNode(Node[LLMNodeData]): sequence=result.sequence, usage=result.usage, finish_reason=result.finish_reason, - files=result.files + sandbox_output_files, + files=result.files + sandbox_output_files + structured_output_files, trace=result.trace, ) @@ -2056,9 +2085,12 @@ class LLMNode(Node[LLMNodeData]): structured_output=self._node_data.structured_output or {}, ) tool_instances.extend( - build_agent_output_tools(tenant_id=self.tenant_id, invoke_from=self.invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, - structured_output_schema=structured_output_schema) + build_agent_output_tools( + tenant_id=self.tenant_id, + invoke_from=self.invoke_from, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + structured_output_schema=structured_output_schema, + ) ) return tool_instances @@ -2564,15 +2596,7 @@ class LLMNode(Node[LLMNodeData]): raise ValueError("No agent result found in tool outputs") output_payload = agent_result.output if isinstance(output_payload, dict): - state.aggregate.structured_output = LLMStructuredOutput( - structured_output=convert_file_refs_in_output( - output=output_payload, - json_schema=LLMNode.fetch_structured_output_schema( - structured_output=self._node_data.structured_output or {}, - ), - tenant_id=self.tenant_id, - ) - ) + state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload) state.aggregate.text = json.dumps(output_payload) elif isinstance(output_payload, str): state.aggregate.text = output_payload diff --git a/api/tests/fixtures/file output schema.yml b/api/tests/fixtures/file output schema.yml index 37fc9c72c7..8ae3072572 100644 --- a/api/tests/fixtures/file output schema.yml +++ b/api/tests/fixtures/file output schema.yml @@ -126,9 +126,8 @@ workflow: additionalProperties: false properties: image: - description: File ID (UUID) of the selected image - format: dify-file-ref - type: string + description: Sandbox file path of the selected image + type: file required: - image type: object diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py index 6d18ac7fc9..2c16de8f6f 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py @@ -1,269 +1,93 @@ """ -Unit tests for file reference detection and conversion. +Unit tests for sandbox file path detection and conversion. """ -import uuid -from unittest.mock import MagicMock, patch - import pytest from core.file import File, FileTransferMethod, FileType from core.llm_generator.output_parser.file_ref import ( - FILE_REF_FORMAT, - convert_file_refs_in_output, - detect_file_ref_fields, - is_file_ref_property, + FILE_PATH_DESCRIPTION_SUFFIX, + adapt_schema_for_sandbox_file_paths, + convert_sandbox_file_paths_in_output, + detect_file_path_fields, + is_file_path_property, ) from core.variables.segments import ArrayFileSegment, FileSegment -class TestIsFileRefProperty: - """Tests for is_file_ref_property function.""" - - def test_valid_file_ref(self): - schema = {"type": "string", "format": FILE_REF_FORMAT} - assert is_file_ref_property(schema) is True - - def test_invalid_type(self): - schema = {"type": "number", "format": FILE_REF_FORMAT} - assert is_file_ref_property(schema) is False - - def test_missing_format(self): - schema = {"type": "string"} - assert is_file_ref_property(schema) is False - - def test_wrong_format(self): - schema = {"type": "string", "format": "uuid"} - assert is_file_ref_property(schema) is False +def _build_file(file_id: str) -> File: + return File( + id=file_id, + tenant_id="tenant_123", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + filename="test.png", + extension=".png", + mime_type="image/png", + size=128, + related_id=file_id, + storage_key="sandbox/path", + ) -class TestDetectFileRefFields: - """Tests for detect_file_ref_fields function.""" +class TestFilePathSchema: + def test_is_file_path_property(self): + assert is_file_path_property({"type": "file"}) is True + assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True + assert is_file_path_property({"type": "string"}) is False - def test_simple_file_ref(self): + def test_detect_file_path_fields(self): schema = { "type": "object", "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, + "image": {"type": "string", "format": "dify-file-ref"}, + "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}}, + "meta": {"type": "object", "properties": {"doc": {"type": "file"}}}, }, } - paths = detect_file_ref_fields(schema) - assert paths == ["image"] + assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"} - def test_multiple_file_refs(self): + def test_adapt_schema_for_sandbox_file_paths(self): schema = { "type": "object", "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - "document": {"type": "string", "format": FILE_REF_FORMAT}, + "image": {"type": "string", "format": "dify-file-ref"}, "name": {"type": "string"}, }, } - paths = detect_file_ref_fields(schema) - assert set(paths) == {"image", "document"} + adapted, fields = adapt_schema_for_sandbox_file_paths(schema) - def test_array_of_file_refs(self): - schema = { - "type": "object", - "properties": { - "files": { - "type": "array", - "items": {"type": "string", "format": FILE_REF_FORMAT}, - }, - }, - } - paths = detect_file_ref_fields(schema) - assert paths == ["files[*]"] - - def test_nested_file_ref(self): - schema = { - "type": "object", - "properties": { - "data": { - "type": "object", - "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - }, - }, - }, - } - paths = detect_file_ref_fields(schema) - assert paths == ["data.image"] - - def test_no_file_refs(self): - schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "count": {"type": "number"}, - }, - } - paths = detect_file_ref_fields(schema) - assert paths == [] - - def test_empty_schema(self): - schema = {} - paths = detect_file_ref_fields(schema) - assert paths == [] - - def test_mixed_schema(self): - schema = { - "type": "object", - "properties": { - "query": {"type": "string"}, - "image": {"type": "string", "format": FILE_REF_FORMAT}, - "documents": { - "type": "array", - "items": {"type": "string", "format": FILE_REF_FORMAT}, - }, - }, - } - paths = detect_file_ref_fields(schema) - assert set(paths) == {"image", "documents[*]"} + assert set(fields) == {"image"} + adapted_image = adapted["properties"]["image"] + assert adapted_image["type"] == "string" + assert "format" not in adapted_image + assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"] -class TestConvertFileRefsInOutput: - """Tests for convert_file_refs_in_output function.""" +class TestConvertSandboxFilePaths: + def test_convert_sandbox_file_paths(self): + output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"} - @pytest.fixture - def mock_file(self): - """Create a mock File object with all required attributes.""" - file = MagicMock(spec=File) - file.type = FileType.IMAGE - file.transfer_method = FileTransferMethod.TOOL_FILE - file.related_id = "test-related-id" - file.remote_url = None - file.tenant_id = "tenant_123" - file.id = None - file.filename = "test.png" - file.extension = ".png" - file.mime_type = "image/png" - file.size = 1024 - file.dify_model_identity = "__dify__file__" - return file + def resolver(path: str) -> File: + return _build_file(path) - @pytest.fixture - def mock_build_from_mapping(self, mock_file): - """Mock the build_from_mapping function.""" - with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock: - mock.return_value = mock_file - yield mock + converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver) - def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file): - file_id = str(uuid.uuid4()) - output = {"image": file_id} - schema = { - "type": "object", - "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - }, - } + assert isinstance(converted["image"], FileSegment) + assert isinstance(converted["files"], ArrayFileSegment) + assert converted["name"] == "demo" + assert [file.id for file in files] == ["a.png", "b.png", "c.png"] - result = convert_file_refs_in_output(output, schema, "tenant_123") + def test_invalid_path_value_raises(self): + def resolver(path: str) -> File: + return _build_file(path) - # Result should be wrapped in FileSegment - assert isinstance(result["image"], FileSegment) - assert result["image"].value == mock_file - mock_build_from_mapping.assert_called_once_with( - mapping={"transfer_method": "tool_file", "tool_file_id": file_id}, - tenant_id="tenant_123", - ) + with pytest.raises(ValueError): + convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver) - def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file): - file_id1 = str(uuid.uuid4()) - file_id2 = str(uuid.uuid4()) - output = {"files": [file_id1, file_id2]} - schema = { - "type": "object", - "properties": { - "files": { - "type": "array", - "items": {"type": "string", "format": FILE_REF_FORMAT}, - }, - }, - } + def test_no_file_paths_returns_output(self): + output = {"name": "demo"} + converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file) - result = convert_file_refs_in_output(output, schema, "tenant_123") - - # Result should be wrapped in ArrayFileSegment - assert isinstance(result["files"], ArrayFileSegment) - assert list(result["files"].value) == [mock_file, mock_file] - assert mock_build_from_mapping.call_count == 2 - - def test_no_conversion_without_file_refs(self): - output = {"name": "test", "count": 5} - schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "count": {"type": "number"}, - }, - } - - result = convert_file_refs_in_output(output, schema, "tenant_123") - - assert result == {"name": "test", "count": 5} - - def test_invalid_uuid_returns_none(self): - output = {"image": "not-a-valid-uuid"} - schema = { - "type": "object", - "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - }, - } - - result = convert_file_refs_in_output(output, schema, "tenant_123") - - assert result["image"] is None - - def test_file_not_found_returns_none(self): - file_id = str(uuid.uuid4()) - output = {"image": file_id} - schema = { - "type": "object", - "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - }, - } - - with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock: - mock.side_effect = ValueError("File not found") - result = convert_file_refs_in_output(output, schema, "tenant_123") - - assert result["image"] is None - - def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file): - file_id = str(uuid.uuid4()) - output = {"query": "search term", "image": file_id, "count": 10} - schema = { - "type": "object", - "properties": { - "query": {"type": "string"}, - "image": {"type": "string", "format": FILE_REF_FORMAT}, - "count": {"type": "number"}, - }, - } - - result = convert_file_refs_in_output(output, schema, "tenant_123") - - assert result["query"] == "search term" - assert isinstance(result["image"], FileSegment) - assert result["image"].value == mock_file - assert result["count"] == 10 - - def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file): - file_id = str(uuid.uuid4()) - original = {"image": file_id} - output = dict(original) - schema = { - "type": "object", - "properties": { - "image": {"type": "string", "format": FILE_REF_FORMAT}, - }, - } - - convert_file_refs_in_output(output, schema, "tenant_123") - - # Original should still contain the string ID - assert original["image"] == file_id + assert converted == output + assert files == [] diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index df73c29004..fd5cd77612 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -1,4 +1,6 @@ +from collections.abc import Mapping from decimal import Decimal +from typing import Any, NotRequired, TypedDict from unittest.mock import MagicMock, patch import pytest @@ -6,16 +8,13 @@ from pydantic import BaseModel, ConfigDict from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( - _get_default_value_for_type, + get_default_value_for_type, fill_defaults_from_schema, invoke_llm_with_pydantic_model, invoke_llm_with_structured_output, ) from core.model_runtime.entities.llm_entities import ( LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, LLMUsage, ) @@ -25,7 +24,30 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelType, + ParameterRule, + ParameterType, +) + + +class StructuredOutputTestCase(TypedDict): + name: str + provider: str + model_name: str + support_structure_output: bool + stream: bool + json_schema: Mapping[str, Any] + expected_llm_response: LLMResult + expected_result_type: type[LLMResultWithStructuredOutput] | None + should_raise: bool + expected_error: NotRequired[type[OutputParserError]] + parameter_rules: NotRequired[list[ParameterRule]] + + +SchemaData = dict[str, Any] def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: @@ -71,7 +93,7 @@ def get_model_instance() -> MagicMock: def test_structured_output_parser(): """Test cases for invoke_llm_with_structured_output function""" - testcases = [ + testcases: list[StructuredOutputTestCase] = [ # Test case 1: Model with native structured output support, non-streaming { "name": "native_structured_output_non_streaming", @@ -88,39 +110,6 @@ def test_structured_output_parser(): "expected_result_type": LLMResultWithStructuredOutput, "should_raise": False, }, - # Test case 2: Model with native structured output support, streaming - { - "name": "native_structured_output_streaming", - "provider": "openai", - "model_name": "gpt-4o", - "support_structure_output": True, - "stream": True, - "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}}, - "expected_llm_response": [ - LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content='{"name":'), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=2), - ), - ), - LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=' "test"}'), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=3), - ), - ), - ], - "expected_result_type": "generator", - "should_raise": False, - }, # Test case 3: Model without native structured output support, non-streaming { "name": "prompt_based_structured_output_non_streaming", @@ -137,78 +126,24 @@ def test_structured_output_parser(): "expected_result_type": LLMResultWithStructuredOutput, "should_raise": False, }, - # Test case 4: Model without native structured output support, streaming { - "name": "prompt_based_structured_output_streaming", - "provider": "anthropic", - "model_name": "claude-3-sonnet", - "support_structure_output": False, - "stream": True, - "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}}, - "expected_llm_response": [ - LLMResultChunk( - model="claude-3-sonnet", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content='{"answer": "test'), - usage=create_mock_usage(prompt_tokens=15, completion_tokens=3), - ), - ), - LLMResultChunk( - model="claude-3-sonnet", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=' response"}'), - usage=create_mock_usage(prompt_tokens=15, completion_tokens=5), - ), - ), - ], - "expected_result_type": "generator", - "should_raise": False, - }, - # Test case 5: Streaming with list content - { - "name": "streaming_with_list_content", + "name": "non_streaming_with_list_content", "provider": "openai", "model_name": "gpt-4o", "support_structure_output": True, - "stream": True, + "stream": False, "json_schema": {"type": "object", "properties": {"data": {"type": "string"}}}, - "expected_llm_response": [ - LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=[ - TextPromptMessageContent(data='{"data":'), - ] - ), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=2), - ), + "expected_llm_response": LLMResult( + model="gpt-4o", + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data='{"data":'), + TextPromptMessageContent(data=' "value"}'), + ] ), - LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=[ - TextPromptMessageContent(data=' "value"}'), - ] - ), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=3), - ), - ), - ], - "expected_result_type": "generator", + usage=create_mock_usage(prompt_tokens=10, completion_tokens=5), + ), + "expected_result_type": LLMResultWithStructuredOutput, "should_raise": False, }, # Test case 6: Error case - non-string LLM response content (non-streaming) @@ -253,7 +188,13 @@ def test_structured_output_parser(): "stream": False, "json_schema": {"type": "object", "properties": {"result": {"type": "string"}}}, "parameter_rules": [ - MagicMock(name="response_format", options=["json_schema"], required=False), + ParameterRule( + name="response_format", + label=I18nObject(en_US="response_format"), + type=ParameterType.STRING, + required=False, + options=["json_schema"], + ), ], "expected_llm_response": LLMResult( model="gpt-4o", @@ -272,7 +213,13 @@ def test_structured_output_parser(): "stream": False, "json_schema": {"type": "object", "properties": {"output": {"type": "string"}}}, "parameter_rules": [ - MagicMock(name="response_format", options=["JSON"], required=False), + ParameterRule( + name="response_format", + label=I18nObject(en_US="response_format"), + type=ParameterType.STRING, + required=False, + options=["JSON"], + ), ], "expected_llm_response": LLMResult( model="claude-3-sonnet", @@ -285,89 +232,72 @@ def test_structured_output_parser(): ] for case in testcases: - # Setup model entity - model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"]) + provider = case["provider"] + model_name = case["model_name"] + support_structure_output = case["support_structure_output"] + json_schema = case["json_schema"] + stream = case["stream"] - # Add parameter rules if specified - if "parameter_rules" in case: - model_schema.parameter_rules = case["parameter_rules"] + model_schema = get_model_entity(provider, model_name, support_structure_output) + + parameter_rules = case.get("parameter_rules") + if parameter_rules is not None: + model_schema.parameter_rules = parameter_rules - # Setup model instance model_instance = get_model_instance() model_instance.invoke_llm.return_value = case["expected_llm_response"] - # Setup prompt messages prompt_messages = [ SystemPromptMessage(content="You are a helpful assistant."), UserPromptMessage(content="Generate a response according to the schema."), ] if case["should_raise"]: - # Test error cases - with pytest.raises(case["expected_error"]): # noqa: PT012 - if case["stream"]: + expected_error = case.get("expected_error", OutputParserError) + with pytest.raises(expected_error): # noqa: PT012 + if stream: result_generator = invoke_llm_with_structured_output( - provider=case["provider"], + provider=provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=case["json_schema"], + json_schema=json_schema, ) - # Consume the generator to trigger the error list(result_generator) else: invoke_llm_with_structured_output( - provider=case["provider"], + provider=provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=case["json_schema"], + json_schema=json_schema, ) else: - # Test successful cases with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: - # Configure json_repair mock for cases that need it if case["name"] == "json_repair_scenario": mock_json_repair.return_value = {"name": "test"} result = invoke_llm_with_structured_output( - provider=case["provider"], + provider=provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=case["json_schema"], + json_schema=json_schema, model_parameters={"temperature": 0.7, "max_tokens": 100}, user="test_user", ) - if case["expected_result_type"] == "generator": - # Test streaming results - assert hasattr(result, "__iter__") - chunks = list(result) - assert len(chunks) > 0 + expected_result_type = case["expected_result_type"] + assert expected_result_type is not None + assert isinstance(result, expected_result_type) + assert result.model == model_name + assert result.structured_output is not None + assert isinstance(result.structured_output, dict) - # Verify all chunks are LLMResultChunkWithStructuredOutput - for chunk in chunks[:-1]: # All except last - assert isinstance(chunk, LLMResultChunkWithStructuredOutput) - assert chunk.model == case["model_name"] - - # Last chunk should have structured output - last_chunk = chunks[-1] - assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput) - assert last_chunk.structured_output is not None - assert isinstance(last_chunk.structured_output, dict) - else: - # Test non-streaming results - assert isinstance(result, case["expected_result_type"]) - assert result.model == case["model_name"] - assert result.structured_output is not None - assert isinstance(result.structured_output, dict) - - # Verify model_instance.invoke_llm was called with correct parameters model_instance.invoke_llm.assert_called_once() call_args = model_instance.invoke_llm.call_args - assert call_args.kwargs["stream"] == case["stream"] + assert call_args.kwargs["stream"] == stream assert call_args.kwargs["user"] == "test_user" assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"] @@ -376,45 +306,32 @@ def test_structured_output_parser(): def test_parse_structured_output_edge_cases(): """Test edge cases for structured output parsing""" - # Test case with list that contains dict (reasoning model scenario) - testcase_list_with_dict = { - "name": "list_with_dict_parsing", - "provider": "deepseek", - "model_name": "deepseek-r1", - "support_structure_output": False, - "stream": False, - "json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}}, - "expected_llm_response": LLMResult( - model="deepseek-r1", - message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=5), - ), - "expected_result_type": LLMResultWithStructuredOutput, - "should_raise": False, - } - - # Setup for list parsing test - model_schema = get_model_entity( - testcase_list_with_dict["provider"], - testcase_list_with_dict["model_name"], - testcase_list_with_dict["support_structure_output"], + provider = "deepseek" + model_name = "deepseek-r1" + support_structure_output = False + json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}} + expected_llm_response = LLMResult( + model="deepseek-r1", + message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'), + usage=create_mock_usage(prompt_tokens=10, completion_tokens=5), ) + model_schema = get_model_entity(provider, model_name, support_structure_output) + model_instance = get_model_instance() - model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"] + model_instance.invoke_llm.return_value = expected_llm_response prompt_messages = [UserPromptMessage(content="Test reasoning")] with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: - # Mock json_repair to return a list with dict mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"] result = invoke_llm_with_structured_output( - provider=testcase_list_with_dict["provider"], + provider=provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=testcase_list_with_dict["json_schema"], + json_schema=json_schema, ) assert isinstance(result, LLMResultWithStructuredOutput) @@ -424,18 +341,16 @@ def test_parse_structured_output_edge_cases(): def test_model_specific_schema_preparation(): """Test schema preparation for different model types""" - # Test Gemini model - gemini_case = { - "provider": "google", - "model_name": "gemini-pro", - "support_structure_output": True, - "stream": False, - "json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False}, + provider = "google" + model_name = "gemini-pro" + support_structure_output = True + json_schema: SchemaData = { + "type": "object", + "properties": {"result": {"type": "boolean"}}, + "additionalProperties": False, } - model_schema = get_model_entity( - gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"] - ) + model_schema = get_model_entity(provider, model_name, support_structure_output) model_instance = get_model_instance() model_instance.invoke_llm.return_value = LLMResult( @@ -447,11 +362,11 @@ def test_model_specific_schema_preparation(): prompt_messages = [UserPromptMessage(content="Test")] result = invoke_llm_with_structured_output( - provider=gemini_case["provider"], + provider=provider, model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=gemini_case["json_schema"], + json_schema=json_schema, ) assert isinstance(result, LLMResultWithStructuredOutput) @@ -493,40 +408,26 @@ def test_structured_output_with_pydantic_model_non_streaming(): assert result.name == "test" -def test_structured_output_with_pydantic_model_streaming(): +def test_structured_output_with_pydantic_model_list_content(): model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True) model_instance = get_model_instance() - - def mock_streaming_response(): - yield LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content='{"name":'), - usage=create_mock_usage(prompt_tokens=8, completion_tokens=2), - ), - ) - yield LLMResultChunk( - model="gpt-4o", - prompt_messages=[UserPromptMessage(content="test")], - system_fingerprint="test", - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=' "test"}'), - usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), - ), - ) - - model_instance.invoke_llm.return_value = mock_streaming_response() + model_instance.invoke_llm.return_value = LLMResult( + model="gpt-4o", + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data='{"name":'), + TextPromptMessageContent(data=' "test"}'), + ] + ), + usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), + ) result = invoke_llm_with_pydantic_model( provider="openai", model_schema=model_schema, model_instance=model_instance, prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")], - output_model=ExampleOutput + output_model=ExampleOutput, ) assert isinstance(result, ExampleOutput) @@ -548,51 +449,51 @@ def test_structured_output_with_pydantic_model_validation_error(): model_schema=model_schema, model_instance=model_instance, prompt_messages=[UserPromptMessage(content="test")], - output_model=ExampleOutput + output_model=ExampleOutput, ) class TestGetDefaultValueForType: - """Test cases for _get_default_value_for_type function""" + """Test cases for get_default_value_for_type function""" def test_string_type(self): - assert _get_default_value_for_type("string") == "" + assert get_default_value_for_type("string") == "" def test_object_type(self): - assert _get_default_value_for_type("object") == {} + assert get_default_value_for_type("object") == {} def test_array_type(self): - assert _get_default_value_for_type("array") == [] + assert get_default_value_for_type("array") == [] def test_number_type(self): - assert _get_default_value_for_type("number") == 0 + assert get_default_value_for_type("number") == 0 def test_integer_type(self): - assert _get_default_value_for_type("integer") == 0 + assert get_default_value_for_type("integer") == 0 def test_boolean_type(self): - assert _get_default_value_for_type("boolean") is False + assert get_default_value_for_type("boolean") is False def test_null_type(self): - assert _get_default_value_for_type("null") is None + assert get_default_value_for_type("null") is None def test_none_type(self): - assert _get_default_value_for_type(None) is None + assert get_default_value_for_type(None) is None def test_unknown_type(self): - assert _get_default_value_for_type("unknown") is None + assert get_default_value_for_type("unknown") is None def test_union_type_string_null(self): # ["string", "null"] should return "" (first non-null type) - assert _get_default_value_for_type(["string", "null"]) == "" + assert get_default_value_for_type(["string", "null"]) == "" def test_union_type_null_first(self): # ["null", "integer"] should return 0 (first non-null type) - assert _get_default_value_for_type(["null", "integer"]) == 0 + assert get_default_value_for_type(["null", "integer"]) == 0 def test_union_type_only_null(self): # ["null"] should return None - assert _get_default_value_for_type(["null"]) is None + assert get_default_value_for_type(["null"]) is None class TestFillDefaultsFromSchema: @@ -600,7 +501,7 @@ class TestFillDefaultsFromSchema: def test_simple_required_fields(self): """Test filling simple required fields""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "name": {"type": "string"}, @@ -609,7 +510,7 @@ class TestFillDefaultsFromSchema: }, "required": ["name", "age"], } - output = {"name": "Alice"} + output: SchemaData = {"name": "Alice"} result = fill_defaults_from_schema(output, schema) @@ -619,7 +520,7 @@ class TestFillDefaultsFromSchema: def test_non_required_fields_not_filled(self): """Test that non-required fields are not filled""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "required_field": {"type": "string"}, @@ -627,7 +528,7 @@ class TestFillDefaultsFromSchema: }, "required": ["required_field"], } - output = {} + output: SchemaData = {} result = fill_defaults_from_schema(output, schema) @@ -636,7 +537,7 @@ class TestFillDefaultsFromSchema: def test_nested_object_required_fields(self): """Test filling nested object required fields""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "user": { @@ -659,7 +560,7 @@ class TestFillDefaultsFromSchema: }, "required": ["user"], } - output = { + output: SchemaData = { "user": { "name": "Alice", "address": { @@ -684,7 +585,7 @@ class TestFillDefaultsFromSchema: def test_missing_nested_object_created(self): """Test that missing required nested objects are created""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "metadata": { @@ -698,7 +599,7 @@ class TestFillDefaultsFromSchema: }, "required": ["metadata"], } - output = {} + output: SchemaData = {} result = fill_defaults_from_schema(output, schema) @@ -710,7 +611,7 @@ class TestFillDefaultsFromSchema: def test_all_types_default_values(self): """Test default values for all types""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "str_field": {"type": "string"}, @@ -722,7 +623,7 @@ class TestFillDefaultsFromSchema: }, "required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"], } - output = {} + output: SchemaData = {} result = fill_defaults_from_schema(output, schema) @@ -737,7 +638,7 @@ class TestFillDefaultsFromSchema: def test_existing_values_preserved(self): """Test that existing values are not overwritten""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "name": {"type": "string"}, @@ -745,7 +646,7 @@ class TestFillDefaultsFromSchema: }, "required": ["name", "count"], } - output = {"name": "Bob", "count": 42} + output: SchemaData = {"name": "Bob", "count": 42} result = fill_defaults_from_schema(output, schema) @@ -753,7 +654,7 @@ class TestFillDefaultsFromSchema: def test_complex_nested_structure(self): """Test complex nested structure with multiple levels""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "user": { @@ -789,7 +690,7 @@ class TestFillDefaultsFromSchema: }, "required": ["user", "tags", "metadata", "is_active"], } - output = { + output: SchemaData = { "user": { "name": "Alice", "age": 25, @@ -829,8 +730,8 @@ class TestFillDefaultsFromSchema: def test_empty_schema(self): """Test with empty schema""" - schema = {} - output = {"any": "value"} + schema: SchemaData = {} + output: SchemaData = {"any": "value"} result = fill_defaults_from_schema(output, schema) @@ -838,14 +739,14 @@ class TestFillDefaultsFromSchema: def test_schema_without_required(self): """Test schema without required field""" - schema = { + schema: SchemaData = { "type": "object", "properties": { "optional1": {"type": "string"}, "optional2": {"type": "integer"}, }, } - output = {} + output: SchemaData = {} result = fill_defaults_from_schema(output, schema)