fix multimodal embedding retrieval test

This commit is contained in:
jyong 2025-12-26 17:05:37 +08:00
parent 901cc64ac9
commit 676063890c
1 changed file with 8 additions and 8 deletions

View File

@ -473,9 +473,7 @@ class RetrievalService:
for segment in segments: for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get( ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
segment.document_id
)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids: if segment.id not in include_segment_ids:
@ -495,8 +493,10 @@ class RetrievalService:
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0) max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
for attachment_info in attachment_infos: for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]] file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0) max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
map_detail = { map_detail = {
"max_score": max_score, "max_score": max_score,
"child_chunks": child_chunk_details, "child_chunks": child_chunk_details,
@ -522,7 +522,7 @@ class RetrievalService:
"score": max_score, "score": max_score,
} }
records.append(record) records.append(record)
# Add child chunks information to records # Add child chunks information to records
for record in records: for record in records:
if record["segment"].id in segment_child_map: if record["segment"].id in segment_child_map:
@ -540,7 +540,7 @@ class RetrievalService:
child_chunks = record.get("child_chunks") child_chunks = record.get("child_chunks")
if not isinstance(child_chunks, list): if not isinstance(child_chunks, list):
child_chunks = None child_chunks = None
if child_chunks: if child_chunks:
child_chunks = sorted(child_chunks, key=lambda x: x.get, reverse=True) child_chunks = sorted(child_chunks, key=lambda x: x.get, reverse=True)
@ -563,7 +563,7 @@ class RetrievalService:
) )
result.append(retrieval_segment) result.append(retrieval_segment)
return sorted(result, key=lambda x: x.score, reverse=True) return sorted(result, key=lambda x: x.score, reverse=True)
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
raise e raise e