From f610f6895f3cd0fe3a1a47b68cb3136f4a250b21 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 26 Dec 2025 21:42:06 +0800 Subject: [PATCH] fix: retrieval test and knowledge retrieval node failed in multimodal mode (#30210) Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/rag/datasource/retrieval_service.py | 169 ++++++++++--------- 1 file changed, 85 insertions(+), 84 deletions(-) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 9807cb4e6a..43912cd75d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -381,10 +381,9 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - segment_file_map = {} valid_dataset_documents = {} - image_doc_ids = [] + image_doc_ids: list[Any] = [] child_index_node_ids = [] index_node_ids = [] doc_to_document_map = {} @@ -417,28 +416,39 @@ class RetrievalService: child_index_node_ids = [i for i in child_index_node_ids if i] index_node_ids = [i for i in index_node_ids if i] - segment_ids = [] + segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map = {} - child_chunk_map = {} - doc_segment_map = {} + attachment_map: dict[str, list[dict[str, Any]]] = {} + child_chunk_map: dict[str, list[ChildChunk]] = {} + doc_segment_map: dict[str, list[str]] = {} with session_factory.create_session() as session: attachments = cls.get_segment_attachment_infos(image_doc_ids, session) for attachment in attachments: segment_ids.append(attachment["segment_id"]) - attachment_map[attachment["segment_id"]] = attachment - doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"] - + if attachment["segment_id"] in attachment_map: + attachment_map[attachment["segment_id"]].append(attachment["attachment_info"]) + else: + attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]] + if attachment["segment_id"] in doc_segment_map: + doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"]) + else: + doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]] child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids)) child_index_nodes = session.execute(child_chunk_stmt).scalars().all() for i in child_index_nodes: segment_ids.append(i.segment_id) - child_chunk_map[i.segment_id] = i - doc_segment_map[i.segment_id] = i.index_node_id + if i.segment_id in child_chunk_map: + child_chunk_map[i.segment_id].append(i) + else: + child_chunk_map[i.segment_id] = [i] + if i.segment_id in doc_segment_map: + doc_segment_map[i.segment_id].append(i.index_node_id) + else: + doc_segment_map[i.segment_id] = [i.index_node_id] if index_node_ids: document_segment_stmt = select(DocumentSegment).where( @@ -448,7 +458,7 @@ class RetrievalService: ) index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore for index_node_segment in index_node_segments: - doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id + doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id] if segment_ids: document_segment_stmt = select(DocumentSegment).where( DocumentSegment.enabled == True, @@ -461,95 +471,86 @@ class RetrievalService: segments.extend(index_node_segments) for segment in segments: - doc_id = doc_segment_map.get(segment.id) - child_chunk = child_chunk_map.get(segment.id) - attachment_info = attachment_map.get(segment.id) + child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) + attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) - if doc_id: - document = doc_to_document_map[doc_id] - ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get( - document.metadata.get("document_id") - ) - - if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - if child_chunk: + if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunks or attachment_infos: + child_chunk_details = [] + max_score = 0.0 + for child_chunk in child_chunks: + document = doc_to_document_map[child_chunk.index_node_id] child_chunk_detail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, "score": document.metadata.get("score", 0.0) if document else 0.0, } - map_detail = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, + child_chunk_details.append(child_chunk_detail) + max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) + for attachment_info in attachment_infos: + file_document = doc_to_document_map[attachment_info["id"]] + max_score = max( + max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0 + ) + + map_detail = { + "max_score": max_score, + "child_chunks": child_chunk_details, } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if child_chunk: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - if segment.id in segment_child_map: - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], - document.metadata.get("score", 0.0) if document else 0.0, - ) - else: - segment_child_map[segment.id] = { - "max_score": document.metadata.get("score", 0.0) if document else 0.0, - "child_chunks": [child_chunk_detail], - } - if attachment_info: - if segment.id in segment_file_map: - segment_file_map[segment.id].append(attachment_info) - else: - segment_file_map[segment.id] = [attachment_info] - else: - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score", 0.0), # type: ignore - } - if attachment_info: - segment_file_map[segment.id] = [attachment_info] - records.append(record) - else: - if attachment_info: - attachment_infos = segment_file_map.get(segment.id, []) - if attachment_info not in attachment_infos: - attachment_infos.append(attachment_info) - segment_file_map[segment.id] = attachment_infos + segment_child_map[segment.id] = map_detail + record: dict[str, Any] = { + "segment": segment, + } + records.append(record) + else: + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + max_score = 0.0 + segment_document = doc_to_document_map.get(segment.index_node_id) + if segment_document: + max_score = max(max_score, segment_document.metadata.get("score", 0.0)) + for attachment_info in attachment_infos: + file_doc = doc_to_document_map.get(attachment_info["id"]) + if file_doc: + max_score = max(max_score, file_doc.metadata.get("score", 0.0)) + record = { + "segment": segment, + "score": max_score, + } + records.append(record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore - if record["segment"].id in segment_file_map: - record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] + if record["segment"].id in attachment_map: + record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] - result = [] + result: list[RetrievalSegments] = [] for record in records: # Extract segment segment = record["segment"] # Extract child_chunks, ensuring it's a list or None - child_chunks = record.get("child_chunks") - if not isinstance(child_chunks, list): - child_chunks = None + raw_child_chunks = record.get("child_chunks") + child_chunks_list: list[RetrievalChildChunk] | None = None + if isinstance(raw_child_chunks, list): + # Sort by score descending + sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True) + child_chunks_list = [ + RetrievalChildChunk( + id=chunk["id"], + content=chunk["content"], + score=chunk.get("score", 0.0), + position=chunk["position"], + ) + for chunk in sorted_chunks + ] # Extract files, ensuring it's a list or None files = record.get("files") @@ -566,11 +567,11 @@ class RetrievalService: # Create RetrievalSegments object retrieval_segment = RetrievalSegments( - segment=segment, child_chunks=child_chunks, score=score, files=files + segment=segment, child_chunks=child_chunks_list, score=score, files=files ) result.append(retrieval_segment) - return result + return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True) except Exception as e: db.session.rollback() raise e