mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
feat: structured output support file type
This commit is contained in:
parent
4f79d09d7b
commit
9b961fb41e
188
api/core/llm_generator/output_parser/file_ref.py
Normal file
188
api/core/llm_generator/output_parser/file_ref.py
Normal file
@ -0,0 +1,188 @@
|
||||
"""
|
||||
File reference detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
|
||||
2. Convert file ID strings to File objects after LLM returns
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file import File
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from factories.file_factory import build_from_mapping
|
||||
|
||||
FILE_REF_FORMAT = "dify-file-ref"
|
||||
|
||||
|
||||
def is_file_ref_property(schema: dict) -> bool:
|
||||
"""Check if a schema property is a file reference."""
|
||||
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
|
||||
|
||||
|
||||
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""
|
||||
Recursively detect file reference fields in schema.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema to analyze
|
||||
path: Current path in the schema (used for recursion)
|
||||
|
||||
Returns:
|
||||
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
|
||||
"""
|
||||
file_ref_paths: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_ref_property(prop_schema):
|
||||
file_ref_paths.append(current_path)
|
||||
elif isinstance(prop_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items", {})
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_ref_property(items_schema):
|
||||
file_ref_paths.append(array_path)
|
||||
elif isinstance(items_schema, dict):
|
||||
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
|
||||
|
||||
return file_ref_paths
|
||||
|
||||
|
||||
def convert_file_refs_in_output(
|
||||
output: Mapping[str, Any],
|
||||
json_schema: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert file ID strings to File objects based on schema.
|
||||
|
||||
Args:
|
||||
output: The structured_output from LLM result
|
||||
json_schema: The original JSON schema (to detect file ref fields)
|
||||
tenant_id: Tenant ID for file lookup
|
||||
|
||||
Returns:
|
||||
Output with file references converted to File objects
|
||||
"""
|
||||
file_ref_paths = detect_file_ref_fields(json_schema)
|
||||
if not file_ref_paths:
|
||||
return dict(output)
|
||||
|
||||
result = _deep_copy_dict(output)
|
||||
|
||||
for path in file_ref_paths:
|
||||
_convert_path_in_place(result, path.split("."), tenant_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Deep copy a mapping to a mutable dict."""
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in obj.items():
|
||||
if isinstance(value, Mapping):
|
||||
result[key] = _deep_copy_dict(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
|
||||
"""Convert file refs at the given path in place, wrapping in Segment types."""
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
# Handle array notation like "files[*]"
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else None
|
||||
target = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target, list):
|
||||
if remaining:
|
||||
# Nested array with remaining path - recurse into each item
|
||||
for item in target:
|
||||
if isinstance(item, dict):
|
||||
_convert_path_in_place(item, remaining, tenant_id)
|
||||
else:
|
||||
# Array of file IDs - convert all and wrap in ArrayFileSegment
|
||||
files: list[File] = []
|
||||
for item in target:
|
||||
file = _convert_file_id(item, tenant_id)
|
||||
if file is not None:
|
||||
files.append(file)
|
||||
# Replace the array with ArrayFileSegment
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
# Leaf node - convert the value and wrap in FileSegment
|
||||
if current in obj:
|
||||
file = _convert_file_id(obj[current], tenant_id)
|
||||
if file is not None:
|
||||
obj[current] = FileSegment(value=file)
|
||||
else:
|
||||
obj[current] = None
|
||||
else:
|
||||
# Recurse into nested object
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, tenant_id)
|
||||
|
||||
|
||||
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
|
||||
"""
|
||||
Convert a file ID string to a File object.
|
||||
|
||||
Tries multiple file sources in order:
|
||||
1. ToolFile (files generated by tools/workflows)
|
||||
2. UploadFile (files uploaded by users)
|
||||
"""
|
||||
if not isinstance(file_id, str):
|
||||
return None
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Try ToolFile first (files generated by tools/workflows)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "tool_file",
|
||||
"tool_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try UploadFile (files uploaded by users)
|
||||
try:
|
||||
return build_from_mapping(
|
||||
mapping={
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_id,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# File not found in any source
|
||||
return None
|
||||
@ -8,6 +8,7 @@ import json_repair
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
|
||||
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@ -57,6 +58,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: Literal[True],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
@ -72,6 +74,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: Literal[False],
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
@ -87,6 +90,7 @@ def invoke_llm_with_structured_output(
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
def invoke_llm_with_structured_output(
|
||||
*,
|
||||
@ -101,20 +105,28 @@ def invoke_llm_with_structured_output(
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
"""
|
||||
Invoke large language model with structured output
|
||||
1. This method invokes model_instance.invoke_llm with json_schema
|
||||
2. Try to parse the result as structured output
|
||||
Invoke large language model with structured output.
|
||||
|
||||
This method invokes model_instance.invoke_llm with json_schema and parses
|
||||
the result as structured output.
|
||||
|
||||
:param provider: model provider name
|
||||
:param model_schema: model schema entity
|
||||
:param model_instance: model instance to invoke
|
||||
:param prompt_messages: prompt messages
|
||||
:param json_schema: json schema
|
||||
:param json_schema: json schema for structured output
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:param callbacks: callbacks
|
||||
:param tenant_id: tenant ID for file reference conversion. When provided and
|
||||
json_schema contains file reference fields (format: "dify-file-ref"),
|
||||
file IDs in the output will be automatically converted to File objects.
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
@ -153,8 +165,18 @@ def invoke_llm_with_structured_output(
|
||||
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
|
||||
)
|
||||
|
||||
structured_output = _parse_structured_output(llm_result.message.content)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
return LLMResultWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(llm_result.message.content),
|
||||
structured_output=structured_output,
|
||||
model=llm_result.model,
|
||||
message=llm_result.message,
|
||||
usage=llm_result.usage,
|
||||
@ -186,8 +208,18 @@ def invoke_llm_with_structured_output(
|
||||
delta=event.delta,
|
||||
)
|
||||
|
||||
structured_output = _parse_structured_output(result_text)
|
||||
|
||||
# Convert file references if tenant_id is provided
|
||||
if tenant_id is not None:
|
||||
structured_output = convert_file_refs_in_output(
|
||||
output=structured_output,
|
||||
json_schema=json_schema,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
structured_output=_parse_structured_output(result_text),
|
||||
structured_output=structured_output,
|
||||
model=model_schema.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=system_fingerprint,
|
||||
|
||||
@ -20,6 +20,7 @@ from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
@ -274,6 +275,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self.node_data.reasoning_format,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
@ -404,6 +406,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
@ -427,6 +430,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
request_start_time = time.perf_counter()
|
||||
@ -612,11 +616,39 @@ class LLMNode(Node[LLMNodeData]):
|
||||
Build context from prompt messages and assistant response.
|
||||
Excludes system messages and includes the current LLM response.
|
||||
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
||||
|
||||
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
|
||||
"""
|
||||
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
|
||||
context_messages: list[PromptMessage] = [
|
||||
LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
|
||||
]
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
return context_messages
|
||||
|
||||
@staticmethod
|
||||
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Truncate multi-modal content base64 data in a message to avoid storing large data.
|
||||
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, truncating multi-modal base64 data
|
||||
new_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
|
||||
truncated_base64 = ""
|
||||
if item.base64_data:
|
||||
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
|
||||
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
|
||||
else:
|
||||
new_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
|
||||
181
api/tests/fixtures/file output schema.yml
vendored
Normal file
181
api/tests/fixtures/file output schema.yml
vendored
Normal file
@ -0,0 +1,181 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: file output schema
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
|
||||
version: null
|
||||
kind: app
|
||||
version: 0.5.0
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
enabled: true
|
||||
fileUploadConfig:
|
||||
attachment_image_file_size_limit: 2
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
file_upload_limit: 10
|
||||
image_file_batch_limit: 10
|
||||
image_file_size_limit: 10
|
||||
single_chunk_attachment_limit: 10
|
||||
video_file_size_limit: 100
|
||||
workflow_file_upload_limit: 10
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
id: 1768292241666-llm
|
||||
source: '1768292241666'
|
||||
sourceHandle: source
|
||||
target: llm
|
||||
targetHandle: target
|
||||
type: custom
|
||||
- data:
|
||||
sourceType: llm
|
||||
targetType: answer
|
||||
id: llm-answer
|
||||
source: llm
|
||||
sourceHandle: source
|
||||
target: answer
|
||||
targetHandle: target
|
||||
type: custom
|
||||
nodes:
|
||||
- data:
|
||||
selected: false
|
||||
title: User Input
|
||||
type: start
|
||||
variables: []
|
||||
height: 73
|
||||
id: '1768292241666'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
memory:
|
||||
query_prompt_template: '{{#sys.query#}}
|
||||
|
||||
|
||||
{{#sys.files#}}'
|
||||
role_prefix:
|
||||
assistant: ''
|
||||
user: ''
|
||||
window:
|
||||
enabled: false
|
||||
size: 10
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: gpt-4o-mini
|
||||
provider: langgenius/openai/openai
|
||||
prompt_template:
|
||||
- id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
|
||||
role: system
|
||||
text: ''
|
||||
selected: false
|
||||
structured_output:
|
||||
schema:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
image:
|
||||
description: File ID (UUID) of the selected image
|
||||
format: dify-file-ref
|
||||
type: string
|
||||
required:
|
||||
- image
|
||||
type: object
|
||||
structured_output_enabled: true
|
||||
title: LLM
|
||||
type: llm
|
||||
vision:
|
||||
configs:
|
||||
detail: high
|
||||
variable_selector:
|
||||
- sys
|
||||
- files
|
||||
enabled: true
|
||||
height: 88
|
||||
id: llm
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
answer: '{{#llm.structured_output.image#}}'
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
variables: []
|
||||
height: 103
|
||||
id: answer
|
||||
position:
|
||||
x: 680
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 680
|
||||
y: 282
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
viewport:
|
||||
x: -149
|
||||
y: 97.5
|
||||
zoom: 1
|
||||
rag_pipeline_variables: []
|
||||
@ -0,0 +1,269 @@
|
||||
"""
|
||||
Unit tests for file reference detection and conversion.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.llm_generator.output_parser.file_ref import (
|
||||
FILE_REF_FORMAT,
|
||||
convert_file_refs_in_output,
|
||||
detect_file_ref_fields,
|
||||
is_file_ref_property,
|
||||
)
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestIsFileRefProperty:
|
||||
"""Tests for is_file_ref_property function."""
|
||||
|
||||
def test_valid_file_ref(self):
|
||||
schema = {"type": "string", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is True
|
||||
|
||||
def test_invalid_type(self):
|
||||
schema = {"type": "number", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_missing_format(self):
|
||||
schema = {"type": "string"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_wrong_format(self):
|
||||
schema = {"type": "string", "format": "uuid"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
|
||||
class TestDetectFileRefFields:
|
||||
"""Tests for detect_file_ref_fields function."""
|
||||
|
||||
def test_simple_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["image"]
|
||||
|
||||
def test_multiple_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"document": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "document"}
|
||||
|
||||
def test_array_of_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["files[*]"]
|
||||
|
||||
def test_nested_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["data.image"]
|
||||
|
||||
def test_no_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_empty_schema(self):
|
||||
schema = {}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_mixed_schema(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "documents[*]"}
|
||||
|
||||
|
||||
class TestConvertFileRefsInOutput:
|
||||
"""Tests for convert_file_refs_in_output function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create a mock File object with all required attributes."""
|
||||
file = MagicMock(spec=File)
|
||||
file.type = FileType.IMAGE
|
||||
file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
file.related_id = "test-related-id"
|
||||
file.remote_url = None
|
||||
file.tenant_id = "tenant_123"
|
||||
file.id = None
|
||||
file.filename = "test.png"
|
||||
file.extension = ".png"
|
||||
file.mime_type = "image/png"
|
||||
file.size = 1024
|
||||
file.dify_model_identity = "__dify__file__"
|
||||
return file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_build_from_mapping(self, mock_file):
|
||||
"""Mock the build_from_mapping function."""
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.return_value = mock_file
|
||||
yield mock
|
||||
|
||||
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in FileSegment
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
mock_build_from_mapping.assert_called_once_with(
|
||||
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||
tenant_id="tenant_123",
|
||||
)
|
||||
|
||||
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
output = {"files": [file_id1, file_id2]}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in ArrayFileSegment
|
||||
assert isinstance(result["files"], ArrayFileSegment)
|
||||
assert list(result["files"].value) == [mock_file, mock_file]
|
||||
assert mock_build_from_mapping.call_count == 2
|
||||
|
||||
def test_no_conversion_without_file_refs(self):
|
||||
output = {"name": "test", "count": 5}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result == {"name": "test", "count": 5}
|
||||
|
||||
def test_invalid_uuid_returns_none(self):
|
||||
output = {"image": "not-a-valid-uuid"}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_file_not_found_returns_none(self):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.side_effect = ValueError("File not found")
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"query": "search term", "image": file_id, "count": 10}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["query"] == "search term"
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
assert result["count"] == 10
|
||||
|
||||
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
original = {"image": file_id}
|
||||
output = dict(original)
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Original should still contain the string ID
|
||||
assert original["image"] == file_id
|
||||
Loading…
Reference in New Issue
Block a user