diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index c4f28ae216..f68e5a4e6b 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -5,6 +5,8 @@ import re import uuid from typing import Any, TypedDict, cast, override +from sqlalchemy.orm import scoped_session + logger = logging.getLogger(__name__) from sqlalchemy import select @@ -411,11 +413,13 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if supports_vision: # First, try to get images from SegmentAttachmentBinding (preferred method) if segment_id: - image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id) + image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments( + tenant_id, segment_id, db.session + ) # If no images from attachments, fall back to extracting from text if not image_files: - image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text) + image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text, db.session) # Build prompt messages prompt_messages = [] @@ -469,7 +473,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return summary_content, usage @staticmethod - def _extract_images_from_text(tenant_id: str, text: str) -> list[File]: + def _extract_images_from_text(tenant_id: str, text: str, session: scoped_session) -> list[File]: """ Extract images from markdown text and convert them to File objects. @@ -518,7 +522,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Get unique IDs for database query unique_upload_file_ids = list(set(upload_file_id_list)) - upload_files = db.session.scalars( + upload_files = session.scalars( select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) ).all() @@ -549,7 +553,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return file_objects @staticmethod - def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]: + def _extract_images_from_segment_attachments( + tenant_id: str, segment_id: str, session: scoped_session + ) -> list[File]: """ Extract images from SegmentAttachmentBinding table (preferred method). This matches how DatasetRetrieval gets segment attachments. @@ -564,7 +570,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): from sqlalchemy import select # Query attachments from SegmentAttachmentBinding table - attachments_with_bindings = db.session.execute( + attachments_with_bindings = session.execute( select(SegmentAttachmentBinding, UploadFile) .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) .where( diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 4368e9cddc..182930b19d 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -528,21 +528,24 @@ class TestParagraphIndexProcessor: session.scalars.return_value = scalars_result with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch( "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", return_value=SimpleNamespace(id="file-1"), ) as mock_builder, patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session) assert len(files) == 1 assert mock_builder.call_count == 1 mock_logger.warning.assert_not_called() def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: - assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == [] + scalars_result = Mock() + scalars_result.all.return_value = [] + session = Mock() + session.scalars.return_value = scalars_result + assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == [] def test_extract_images_from_text_logs_when_build_fails(self) -> None: text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)" @@ -562,14 +565,13 @@ class TestParagraphIndexProcessor: session.scalars.return_value = scalars_result with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch( "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", side_effect=RuntimeError("build failed"), ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text, session) assert files == [] mock_logger.warning.assert_called_once() @@ -608,10 +610,9 @@ class TestParagraphIndexProcessor: session.execute.return_value = execute_result with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1", session) assert len(files) == 1 mock_logger.warning.assert_called_once() @@ -622,7 +623,6 @@ class TestParagraphIndexProcessor: session = Mock() session.execute.return_value = execute_result - with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session): - empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1", session) assert empty_files == []