mirror of https://github.com/langgenius/dify.git
feat: implement strict type validation for remote file uploads (#27010)
This commit is contained in:
parent
598dd1f816
commit
e4b5b0e5fd
|
|
@ -166,7 +166,10 @@ def _build_from_local_file(
|
|||
if strict_type_validation and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
|
|
@ -214,9 +217,10 @@ def _build_from_remote_url(
|
|||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = (
|
||||
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
|
||||
)
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
|
|
@ -238,10 +242,17 @@ def _build_from_remote_url(
|
|||
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||
extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin")
|
||||
|
||||
file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
if file_type.value != mapping.get("type", "custom"):
|
||||
detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type)
|
||||
specified_type = mapping.get("type")
|
||||
|
||||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
|
|
@ -331,7 +342,10 @@ def _build_from_tool_file(
|
|||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
|
|
@ -376,7 +390,10 @@ def _build_from_datasource_file(
|
|||
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
|
||||
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
|
||||
|
||||
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
|
||||
if specified_type and specified_type != "custom":
|
||||
file_type = FileType(specified_type)
|
||||
else:
|
||||
file_type = detected_file_type
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
|
|
|
|||
|
|
@ -150,6 +150,42 @@ def test_build_from_remote_url(mock_http_head):
|
|||
assert file.size == 2048
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("image", True, None),
|
||||
("document", False, "Detected file type does not match the specified type"),
|
||||
("video", False, "Detected file type does not match the specified type"),
|
||||
],
|
||||
)
|
||||
def test_build_from_remote_url_strict_validation(mock_http_head, file_type, should_pass, expected_error):
|
||||
"""Test strict type validation for remote_url."""
|
||||
mapping = {
|
||||
"transfer_method": "remote_url",
|
||||
"url": TEST_REMOTE_URL,
|
||||
"type": file_type,
|
||||
}
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
assert file.type == FileType(file_type)
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
|
||||
|
||||
def test_build_from_remote_url_without_strict_validation(mock_http_head):
|
||||
"""Test that remote_url allows type mismatch when strict_type_validation is False."""
|
||||
mapping = {
|
||||
"transfer_method": "remote_url",
|
||||
"url": TEST_REMOTE_URL,
|
||||
"type": "document",
|
||||
}
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=False)
|
||||
assert file.transfer_method == FileTransferMethod.REMOTE_URL
|
||||
assert file.type == FileType.DOCUMENT
|
||||
assert file.filename == "remote_test.jpg"
|
||||
|
||||
|
||||
def test_tool_file_not_found():
|
||||
"""Test ToolFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue