diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 231e805460..2316e45179 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -166,7 +166,10 @@ def _build_from_local_file( if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -214,9 +217,10 @@ def _build_from_remote_url( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type - ) + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -238,10 +242,17 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - if file_type.value != mapping.get("type", "custom"): + detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) + specified_type = mapping.get("type") + + if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type + return File( id=mapping.get("id"), filename=filename, @@ -331,7 +342,10 @@ def _build_from_tool_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("id"), @@ -376,7 +390,10 @@ def _build_from_datasource_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type + if specified_type and specified_type != "custom": + file_type = FileType(specified_type) + else: + file_type = detected_file_type return File( id=mapping.get("datasource_file_id"), diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 39280c9267..77c4956c04 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -150,6 +150,42 @@ def test_build_from_remote_url(mock_http_head): assert file.size == 2048 +@pytest.mark.parametrize( + ("file_type", "should_pass", "expected_error"), + [ + ("image", True, None), + ("document", False, "Detected file type does not match the specified type"), + ("video", False, "Detected file type does not match the specified type"), + ], +) +def test_build_from_remote_url_strict_validation(mock_http_head, file_type, should_pass, expected_error): + """Test strict type validation for remote_url.""" + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": file_type, + } + if should_pass: + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + assert file.type == FileType(file_type) + else: + with pytest.raises(ValueError, match=expected_error): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True) + + +def test_build_from_remote_url_without_strict_validation(mock_http_head): + """Test that remote_url allows type mismatch when strict_type_validation is False.""" + mapping = { + "transfer_method": "remote_url", + "url": TEST_REMOTE_URL, + "type": "document", + } + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=False) + assert file.transfer_method == FileTransferMethod.REMOTE_URL + assert file.type == FileType.DOCUMENT + assert file.filename == "remote_test.jpg" + + def test_tool_file_not_found(): """Test ToolFile not found in database.""" with patch("factories.file_factory.db.session.scalar", return_value=None):