mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
feat(api): support embedded Excel images in knowledge import (#37104)
This commit is contained in:
parent
759b4cbad3
commit
813bfea730
@ -316,6 +316,7 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
deleted_preview_images = False
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
@ -368,6 +369,10 @@ class IndexingRunner:
|
||||
upload_file_id,
|
||||
)
|
||||
db.session.delete(image_file)
|
||||
deleted_preview_images = True
|
||||
|
||||
if deleted_preview_images:
|
||||
db.session.commit()
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
@ -1,13 +1,32 @@
|
||||
"""Abstract interface for document loader implementations."""
|
||||
"""Excel document extractor used for RAG ingestion.
|
||||
|
||||
Supports cell hyperlinks for both `.xls` and `.xlsx`, and embedded worksheet images
|
||||
for `.xlsx` files by converting them into markdown image links. Embedded images are
|
||||
stored with deterministic keys derived from the source upload file and anchor cell so
|
||||
retries can safely reuse the same assets.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import TypedDict, override
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Candidate(TypedDict):
|
||||
@ -16,17 +35,42 @@ class Candidate(TypedDict):
|
||||
map: dict[int, str]
|
||||
|
||||
|
||||
class SheetImageCandidate(TypedDict):
|
||||
anchor: tuple[int, int]
|
||||
content_hash: str
|
||||
file_key: str
|
||||
image_bytes: bytes
|
||||
image_ext: str
|
||||
|
||||
|
||||
class ExcelExtractor(BaseExtractor):
|
||||
"""Load Excel files.
|
||||
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, encoding: str | None = None, autodetect_encoding: bool = False):
|
||||
_file_path: str
|
||||
_encoding: str | None
|
||||
_autodetect_encoding: bool
|
||||
_tenant_id: str | None
|
||||
_user_id: str | None
|
||||
_source_file_id: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
source_file_id: str | None = None,
|
||||
encoding: str | None = None,
|
||||
autodetect_encoding: bool = False,
|
||||
):
|
||||
"""Initialize with file path."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._source_file_id = source_file_id
|
||||
self._encoding = encoding
|
||||
self._autodetect_encoding = autodetect_encoding
|
||||
|
||||
@ -37,7 +81,8 @@ class ExcelExtractor(BaseExtractor):
|
||||
file_extension = os.path.splitext(self._file_path)[-1].lower()
|
||||
|
||||
if file_extension == ".xlsx":
|
||||
wb = load_workbook(self._file_path, read_only=True, data_only=True)
|
||||
# Worksheet drawing objects, including embedded images, are not available in read-only mode.
|
||||
wb = load_workbook(self._file_path, data_only=True)
|
||||
try:
|
||||
for sheet_name in wb.sheetnames:
|
||||
sheet = wb[sheet_name]
|
||||
@ -45,10 +90,15 @@ class ExcelExtractor(BaseExtractor):
|
||||
if not column_map:
|
||||
continue
|
||||
start_row = header_row_idx + 1
|
||||
sheet_image_map = self._extract_images_from_sheet(
|
||||
sheet_name=sheet_name,
|
||||
sheet=sheet,
|
||||
valid_columns={column_idx + 1 for column_idx in column_map},
|
||||
min_row=start_row,
|
||||
)
|
||||
for row in sheet.iter_rows(min_row=start_row, max_col=max_col_idx, values_only=False):
|
||||
if all(cell.value is None for cell in row):
|
||||
continue
|
||||
page_content = []
|
||||
row_has_content = False
|
||||
for col_idx, cell in enumerate(row):
|
||||
value = cell.value
|
||||
if col_idx in column_map:
|
||||
@ -56,14 +106,27 @@ class ExcelExtractor(BaseExtractor):
|
||||
if hasattr(cell, "hyperlink") and cell.hyperlink:
|
||||
target = getattr(cell.hyperlink, "target", None)
|
||||
if target:
|
||||
value = f"[{value}]({target})"
|
||||
display_value = value if value is not None and str(value).strip() else target
|
||||
value = f"[{display_value}]({target})"
|
||||
cell_row = getattr(cell, "row", None)
|
||||
cell_column = getattr(cell, "column", None)
|
||||
image_links = (
|
||||
sheet_image_map.get((cell_row, cell_column), [])
|
||||
if isinstance(cell_row, int) and isinstance(cell_column, int)
|
||||
else []
|
||||
)
|
||||
if value is None:
|
||||
value = ""
|
||||
elif not isinstance(value, str):
|
||||
value = str(value)
|
||||
value = value.strip().replace('"', '\\"')
|
||||
if image_links:
|
||||
value = " ".join(filter(None, [value, " ".join(image_links)]))
|
||||
value = value.strip()
|
||||
if value:
|
||||
row_has_content = True
|
||||
value = value.replace('"', '\\"')
|
||||
page_content.append(f'"{col_name}":"{value}"')
|
||||
if page_content:
|
||||
if row_has_content and page_content:
|
||||
documents.append(
|
||||
Document(page_content=";".join(page_content), metadata={"source": self._file_path})
|
||||
)
|
||||
@ -89,6 +152,166 @@ class ExcelExtractor(BaseExtractor):
|
||||
|
||||
return documents
|
||||
|
||||
def _extract_images_from_sheet(
|
||||
self, sheet_name: str, sheet, valid_columns: set[int], min_row: int
|
||||
) -> dict[tuple[int, int], list[str]]:
|
||||
"""
|
||||
Extract embedded worksheet images and map them to their anchor cell.
|
||||
|
||||
Images are stored with deterministic keys derived from the source upload file,
|
||||
sheet, anchor cell, and content hash so retried tasks can reuse the same
|
||||
UploadFile rows and storage objects.
|
||||
"""
|
||||
if not self._tenant_id or not self._user_id or not self._source_file_id:
|
||||
return {}
|
||||
|
||||
images = getattr(sheet, "_images", None) or []
|
||||
image_candidates: list[SheetImageCandidate] = []
|
||||
|
||||
for image in images:
|
||||
marker = getattr(getattr(image, "anchor", None), "_from", None)
|
||||
row_idx = getattr(marker, "row", None)
|
||||
col_idx = getattr(marker, "col", None)
|
||||
if row_idx is None or col_idx is None:
|
||||
continue
|
||||
if row_idx + 1 < min_row or col_idx + 1 not in valid_columns:
|
||||
continue
|
||||
|
||||
image_bytes = self._get_image_bytes(image)
|
||||
if not image_bytes:
|
||||
continue
|
||||
|
||||
image_ext = self._get_image_extension(image)
|
||||
if not image_ext:
|
||||
continue
|
||||
|
||||
anchor_row = row_idx + 1
|
||||
anchor_column = col_idx + 1
|
||||
content_hash = self._hash_image_bytes(image_bytes)
|
||||
image_candidates.append(
|
||||
{
|
||||
"anchor": (anchor_row, anchor_column),
|
||||
"content_hash": content_hash,
|
||||
"file_key": self._build_image_file_key(
|
||||
sheet_name=sheet_name,
|
||||
anchor_row=anchor_row,
|
||||
anchor_column=anchor_column,
|
||||
content_hash=content_hash,
|
||||
image_ext=image_ext,
|
||||
),
|
||||
"image_bytes": image_bytes,
|
||||
"image_ext": image_ext,
|
||||
}
|
||||
)
|
||||
|
||||
if not image_candidates:
|
||||
return {}
|
||||
|
||||
image_map: dict[tuple[int, int], list[str]] = {}
|
||||
base_url = dify_config.FILES_URL
|
||||
candidate_keys = sorted({candidate["file_key"] for candidate in image_candidates})
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
existing_upload_files = session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == self._tenant_id,
|
||||
UploadFile.key.in_(candidate_keys),
|
||||
)
|
||||
).all()
|
||||
upload_files_by_key = {upload_file.key: upload_file for upload_file in existing_upload_files}
|
||||
new_upload_files: list[UploadFile] = []
|
||||
|
||||
for candidate in image_candidates:
|
||||
upload_file = upload_files_by_key.get(candidate["file_key"])
|
||||
if upload_file is None:
|
||||
storage.save(candidate["file_key"], candidate["image_bytes"])
|
||||
mime_type, _ = mimetypes.guess_type(candidate["file_key"])
|
||||
upload_file = UploadFile(
|
||||
tenant_id=self._tenant_id,
|
||||
storage_type=StorageType(dify_config.STORAGE_TYPE),
|
||||
key=candidate["file_key"],
|
||||
name=candidate["file_key"],
|
||||
size=len(candidate["image_bytes"]),
|
||||
extension=candidate["image_ext"],
|
||||
mime_type=mime_type or "",
|
||||
created_by=self._user_id,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_at=naive_utc_now(),
|
||||
used=True,
|
||||
used_by=self._user_id,
|
||||
used_at=naive_utc_now(),
|
||||
hash=candidate["content_hash"],
|
||||
)
|
||||
upload_files_by_key[candidate["file_key"]] = upload_file
|
||||
new_upload_files.append(upload_file)
|
||||
|
||||
image_map.setdefault(candidate["anchor"], []).append(
|
||||
f""
|
||||
)
|
||||
|
||||
if new_upload_files:
|
||||
session.add_all(new_upload_files)
|
||||
session.commit()
|
||||
|
||||
return image_map
|
||||
|
||||
@staticmethod
|
||||
def _hash_image_bytes(image_bytes: bytes) -> str:
|
||||
"""Return a stable content hash for extracted image bytes."""
|
||||
return hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
def _build_image_file_key(
|
||||
self,
|
||||
*,
|
||||
sheet_name: str,
|
||||
anchor_row: int,
|
||||
anchor_column: int,
|
||||
content_hash: str,
|
||||
image_ext: str,
|
||||
) -> str:
|
||||
"""Build a deterministic storage key for an embedded worksheet image."""
|
||||
assert self._tenant_id is not None, "tenant_id is required for image extraction"
|
||||
assert self._source_file_id is not None, "source_file_id is required for image extraction"
|
||||
|
||||
normalized_ext = image_ext.strip().lower()
|
||||
sheet_hash = hashlib.sha256(sheet_name.encode("utf-8")).hexdigest()[:16]
|
||||
return (
|
||||
f"image_files/{self._tenant_id}/{self._source_file_id}/"
|
||||
f"{sheet_hash}_r{anchor_row}_c{anchor_column}_{content_hash}.{normalized_ext}"
|
||||
)
|
||||
|
||||
def _get_image_bytes(self, image) -> bytes | None:
|
||||
"""Return embedded image bytes from an openpyxl image object."""
|
||||
data_loader = getattr(image, "_data", None)
|
||||
if not callable(data_loader):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = data_loader()
|
||||
if isinstance(data, bytes):
|
||||
return data
|
||||
if isinstance(data, bytearray):
|
||||
return bytes(data)
|
||||
logger.warning("Unexpected embedded image payload type: %s", type(data).__name__)
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("Failed to read embedded image bytes from Excel sheet", exc_info=True)
|
||||
return None
|
||||
|
||||
def _get_image_extension(self, image) -> str | None:
|
||||
"""Resolve an image extension from openpyxl metadata."""
|
||||
image_format = getattr(image, "format", None)
|
||||
if isinstance(image_format, str) and image_format.strip():
|
||||
return image_format.strip().lower()
|
||||
|
||||
image_path = getattr(image, "path", None)
|
||||
if isinstance(image_path, str):
|
||||
_, extension = os.path.splitext(image_path)
|
||||
if extension:
|
||||
return extension.lstrip(".").lower()
|
||||
|
||||
return None
|
||||
|
||||
def _find_header_and_columns(self, sheet, scan_rows=10) -> tuple[int, dict[int, str], int]:
|
||||
"""
|
||||
Scan first N rows to find the most likely header row.
|
||||
|
||||
@ -113,7 +113,12 @@ class ExtractProcessor:
|
||||
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or ""
|
||||
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
extractor = ExcelExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id,
|
||||
upload_file.created_by,
|
||||
upload_file.id,
|
||||
)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
@ -151,7 +156,12 @@ class ExtractProcessor:
|
||||
extractor = TextExtractor(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
extractor = ExcelExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id,
|
||||
upload_file.created_by,
|
||||
upload_file.id,
|
||||
)
|
||||
elif file_extension == ".pdf":
|
||||
assert upload_file is not None
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
|
||||
@ -11,12 +11,15 @@ class _FakeCell:
|
||||
def __init__(self, value, hyperlink=None):
|
||||
self.value = value
|
||||
self.hyperlink = hyperlink
|
||||
self.row = 0
|
||||
self.column = 0
|
||||
|
||||
|
||||
class _FakeSheet:
|
||||
def __init__(self, header_rows, data_rows):
|
||||
def __init__(self, header_rows, data_rows, images=None):
|
||||
self._header_rows = header_rows
|
||||
self._data_rows = data_rows
|
||||
self._images = images or []
|
||||
|
||||
def iter_rows(self, min_row=1, max_row=None, max_col=None, values_only=False):
|
||||
if values_only:
|
||||
@ -24,11 +27,12 @@ class _FakeSheet:
|
||||
yield tuple(row)
|
||||
return
|
||||
|
||||
for row in self._data_rows:
|
||||
if max_col is not None:
|
||||
yield tuple(row[:max_col])
|
||||
else:
|
||||
yield tuple(row)
|
||||
for row_idx, row in enumerate(self._data_rows, start=min_row):
|
||||
materialized_row = tuple(row[:max_col] if max_col is not None else row)
|
||||
for col_idx, cell in enumerate(materialized_row, start=1):
|
||||
cell.row = row_idx
|
||||
cell.column = col_idx
|
||||
yield materialized_row
|
||||
|
||||
|
||||
class _FakeWorkbook:
|
||||
@ -44,6 +48,94 @@ class _FakeWorkbook:
|
||||
self.closed = True
|
||||
|
||||
|
||||
class _FakeImage:
|
||||
def __init__(self, data: bytes, row: int, col: int, image_format: str = "png"):
|
||||
self._raw_data = data
|
||||
self.anchor = SimpleNamespace(_from=SimpleNamespace(row=row, col=col))
|
||||
self.format = image_format
|
||||
|
||||
def _data(self) -> bytes:
|
||||
return self._raw_data
|
||||
|
||||
|
||||
class _FieldExpression:
|
||||
def __eq__(self, other):
|
||||
return ("eq", other)
|
||||
|
||||
def in_(self, values):
|
||||
return ("in", tuple(values))
|
||||
|
||||
|
||||
class _SelectStub:
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
class _FakeUploadFile:
|
||||
tenant_id = _FieldExpression()
|
||||
key = _FieldExpression()
|
||||
_i = 0
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
type(self)._i += 1
|
||||
self.id = f"u{self._i}"
|
||||
self.key = kwargs["key"]
|
||||
|
||||
|
||||
class _PersistentSession:
|
||||
def __init__(self, persisted):
|
||||
self._persisted = persisted
|
||||
self.added = []
|
||||
self.commit_count = 0
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def scalars(self, _stmt):
|
||||
return SimpleNamespace(all=lambda: list(self._persisted.values()))
|
||||
|
||||
def add_all(self, objects) -> None:
|
||||
self.added.extend(objects)
|
||||
|
||||
def commit(self) -> None:
|
||||
self.commit_count += 1
|
||||
for upload_file in self.added:
|
||||
self._persisted[upload_file.key] = upload_file
|
||||
self.added.clear()
|
||||
|
||||
|
||||
class _PersistentSessionFactory:
|
||||
def __init__(self):
|
||||
self.persisted = {}
|
||||
self.sessions = []
|
||||
|
||||
def create_session(self):
|
||||
session = _PersistentSession(self.persisted)
|
||||
self.sessions.append(session)
|
||||
return session
|
||||
|
||||
|
||||
def _patch_image_persistence(monkeypatch: pytest.MonkeyPatch):
|
||||
saves: list[tuple[str, bytes]] = []
|
||||
session_factory = _PersistentSessionFactory()
|
||||
|
||||
def save(key: str, data: bytes) -> None:
|
||||
saves.append((key, data))
|
||||
|
||||
_FakeUploadFile._i = 0
|
||||
monkeypatch.setattr(excel_module, "storage", SimpleNamespace(save=save))
|
||||
monkeypatch.setattr(excel_module, "session_factory", session_factory)
|
||||
monkeypatch.setattr(excel_module, "select", lambda *args, **kwargs: _SelectStub())
|
||||
monkeypatch.setattr(excel_module, "UploadFile", _FakeUploadFile)
|
||||
monkeypatch.setattr(excel_module.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(excel_module.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
return saves, session_factory
|
||||
|
||||
|
||||
class TestExcelExtractor:
|
||||
def test_extract_xlsx_with_hyperlinks_and_sheet_skip(self, monkeypatch: pytest.MonkeyPatch):
|
||||
sheet_with_data = _FakeSheet(
|
||||
@ -68,6 +160,121 @@ class TestExcelExtractor:
|
||||
assert docs[1].page_content == '"Name":"";"Link":"123"'
|
||||
assert all(doc.metadata["source"] == "/tmp/sample.xlsx" for doc in docs)
|
||||
|
||||
def test_extract_xlsx_turns_embedded_images_into_markdown_links(self, monkeypatch: pytest.MonkeyPatch):
|
||||
image_bytes = b"\x89PNG\r\n\x1a\nexcel-image"
|
||||
sheet = _FakeSheet(
|
||||
header_rows=[("Question", "Answer", "Image")],
|
||||
data_rows=[
|
||||
(_FakeCell("Q1"), _FakeCell("A1"), _FakeCell(None)),
|
||||
(_FakeCell("Q2"), _FakeCell("A2"), _FakeCell(None)),
|
||||
],
|
||||
images=[
|
||||
_FakeImage(image_bytes, row=1, col=2),
|
||||
_FakeImage(image_bytes, row=1, col=2),
|
||||
],
|
||||
)
|
||||
workbook = _FakeWorkbook({"Data": sheet})
|
||||
monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook)
|
||||
saves, session_factory = _patch_image_persistence(monkeypatch)
|
||||
|
||||
extractor = ExcelExtractor(
|
||||
"/tmp/sample.xlsx",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
source_file_id="source-file-1",
|
||||
)
|
||||
docs = extractor.extract()
|
||||
|
||||
assert workbook.closed is True
|
||||
assert len(docs) == 2
|
||||
assert docs[0].page_content == (
|
||||
'"Question":"Q1";"Answer":"A1";'
|
||||
'"Image":" '
|
||||
'"'
|
||||
)
|
||||
assert docs[1].page_content == '"Question":"Q2";"Answer":"A2";"Image":""'
|
||||
assert len(saves) == 1
|
||||
assert saves[0][0].startswith("image_files/tenant-1/source-file-1/")
|
||||
assert saves[0][0].endswith(".png")
|
||||
assert saves[0][1] == image_bytes
|
||||
assert len(session_factory.persisted) == 1
|
||||
assert [session.commit_count for session in session_factory.sessions] == [1]
|
||||
|
||||
def test_extract_xlsx_keeps_rows_with_only_embedded_images(self, monkeypatch: pytest.MonkeyPatch):
|
||||
image_bytes = b"\x89PNG\r\n\x1a\nimage-only-row"
|
||||
sheet = _FakeSheet(
|
||||
header_rows=[("Question", "Answer", "Image")],
|
||||
data_rows=[
|
||||
(_FakeCell(None), _FakeCell(None), _FakeCell(None)),
|
||||
(_FakeCell(None), _FakeCell(None), _FakeCell(None)),
|
||||
],
|
||||
images=[_FakeImage(image_bytes, row=1, col=2)],
|
||||
)
|
||||
workbook = _FakeWorkbook({"Data": sheet})
|
||||
monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook)
|
||||
saves, session_factory = _patch_image_persistence(monkeypatch)
|
||||
|
||||
extractor = ExcelExtractor(
|
||||
"/tmp/sample.xlsx",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
source_file_id="source-file-1",
|
||||
)
|
||||
docs = extractor.extract()
|
||||
|
||||
assert workbook.closed is True
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == (
|
||||
'"Question":"";"Answer":"";"Image":""'
|
||||
)
|
||||
assert len(saves) == 1
|
||||
assert len(session_factory.persisted) == 1
|
||||
assert [session.commit_count for session in session_factory.sessions] == [1]
|
||||
|
||||
def test_extract_xlsx_reuses_existing_embedded_image_uploads_on_retry(self, monkeypatch: pytest.MonkeyPatch):
|
||||
image_bytes = b"\x89PNG\r\n\x1a\nretry-safe-image"
|
||||
workbooks = [
|
||||
_FakeWorkbook(
|
||||
{
|
||||
"Data": _FakeSheet(
|
||||
header_rows=[("Question", "Answer", "Image")],
|
||||
data_rows=[(_FakeCell("Q1"), _FakeCell("A1"), _FakeCell(None))],
|
||||
images=[_FakeImage(image_bytes, row=1, col=2)],
|
||||
)
|
||||
}
|
||||
),
|
||||
_FakeWorkbook(
|
||||
{
|
||||
"Data": _FakeSheet(
|
||||
header_rows=[("Question", "Answer", "Image")],
|
||||
data_rows=[(_FakeCell("Q1"), _FakeCell("A1"), _FakeCell(None))],
|
||||
images=[_FakeImage(image_bytes, row=1, col=2)],
|
||||
)
|
||||
}
|
||||
),
|
||||
]
|
||||
monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbooks.pop(0))
|
||||
saves, session_factory = _patch_image_persistence(monkeypatch)
|
||||
|
||||
extractor = ExcelExtractor(
|
||||
"/tmp/sample.xlsx",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
source_file_id="source-file-1",
|
||||
)
|
||||
first_docs = extractor.extract()
|
||||
second_docs = extractor.extract()
|
||||
|
||||
expected_page_content = (
|
||||
'"Question":"Q1";"Answer":"A1";"Image":""'
|
||||
)
|
||||
|
||||
assert first_docs[0].page_content == expected_page_content
|
||||
assert second_docs[0].page_content == expected_page_content
|
||||
assert len(saves) == 1
|
||||
assert len(session_factory.persisted) == 1
|
||||
assert [session.commit_count for session in session_factory.sessions] == [1, 0]
|
||||
|
||||
def test_extract_xls_path(self, monkeypatch: pytest.MonkeyPatch):
|
||||
class FakeExcelFile:
|
||||
sheet_names = ["Sheet1"]
|
||||
|
||||
@ -139,7 +139,12 @@ class TestExtractProcessorFileRouting:
|
||||
|
||||
setting = SimpleNamespace(
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=SimpleNamespace(key=f"uploaded{extension}", tenant_id="tenant-1", created_by="user-1"),
|
||||
upload_file=SimpleNamespace(
|
||||
id="upload-file-1",
|
||||
key=f"uploaded{extension}",
|
||||
tenant_id="tenant-1",
|
||||
created_by="user-1",
|
||||
),
|
||||
)
|
||||
|
||||
docs = ExtractProcessor.extract(setting, is_automatic=is_automatic)
|
||||
@ -200,6 +205,13 @@ class TestExtractProcessorFileRouting:
|
||||
|
||||
assert extractor_name == expected_extractor
|
||||
|
||||
def test_extract_routes_excel_with_upload_context(self, monkeypatch: pytest.MonkeyPatch):
|
||||
extractor_name, args, kwargs = self._run_extract_for_extension(monkeypatch, ".xlsx", etl_type="SelfHosted")
|
||||
|
||||
assert extractor_name == "ExcelExtractor"
|
||||
assert args[1:] == ("tenant-1", "user-1", "upload-file-1")
|
||||
assert kwargs == {}
|
||||
|
||||
def test_extract_requires_upload_file_when_file_path_not_provided(self):
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ for the full indexing pipeline are handled separately in the integration test su
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
@ -1424,6 +1425,42 @@ class TestIndexingRunnerEstimate:
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
def test_indexing_estimate_commits_preview_image_cleanup(self, mock_dependencies):
|
||||
"""Test indexing estimate persists cleanup for preview-only extracted images."""
|
||||
runner = IndexingRunner()
|
||||
tenant_id = str(uuid.uuid4())
|
||||
mock_processor = MagicMock()
|
||||
mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor
|
||||
|
||||
preview_doc = Document(
|
||||
page_content="",
|
||||
metadata={},
|
||||
)
|
||||
mock_processor.extract.return_value = [preview_doc]
|
||||
mock_processor.transform.return_value = [preview_doc]
|
||||
|
||||
image_file = SimpleNamespace(key="image_files/tenant-1/source-file-1/image.png")
|
||||
mock_dependencies["db"].session.scalar.return_value = image_file
|
||||
|
||||
with (
|
||||
patch("core.indexing_runner.get_image_upload_file_ids", return_value=["image-1"]),
|
||||
patch("core.indexing_runner.storage") as mock_storage,
|
||||
patch("core.indexing_runner.dify_config") as mock_config,
|
||||
):
|
||||
mock_config.BILLING_ENABLED = False
|
||||
|
||||
result = runner.indexing_estimate(
|
||||
tenant_id=tenant_id,
|
||||
extract_settings=[MagicMock()],
|
||||
tmp_processing_rule={"mode": "automatic", "rules": {}},
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
|
||||
assert result.total_segments == 1
|
||||
mock_storage.delete.assert_called_once_with(image_file.key)
|
||||
mock_dependencies["db"].session.delete.assert_called_once_with(image_file)
|
||||
mock_dependencies["db"].session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestIndexingRunnerProcessChunk:
|
||||
"""Unit tests for chunk processing in parallel.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user