mirror of https://github.com/langgenius/dify.git
fix: retrieval test and knowledge retrieval node failed in multimodal mode (#30210)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
d20a8d5b77
commit
f610f6895f
|
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
|
@ -381,10 +381,9 @@ class RetrievalService:
|
|||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids = []
|
||||
image_doc_ids: list[Any] = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
|
|
@ -417,28 +416,39 @@ class RetrievalService:
|
|||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||
index_node_ids = [i for i in index_node_ids if i]
|
||||
|
||||
segment_ids = []
|
||||
segment_ids: list[str] = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map = {}
|
||||
child_chunk_map = {}
|
||||
doc_segment_map = {}
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
|
||||
for attachment in attachments:
|
||||
segment_ids.append(attachment["segment_id"])
|
||||
attachment_map[attachment["segment_id"]] = attachment
|
||||
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
||||
|
||||
if attachment["segment_id"] in attachment_map:
|
||||
attachment_map[attachment["segment_id"]].append(attachment["attachment_info"])
|
||||
else:
|
||||
attachment_map[attachment["segment_id"]] = [attachment["attachment_info"]]
|
||||
if attachment["segment_id"] in doc_segment_map:
|
||||
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
|
||||
else:
|
||||
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
for i in child_index_nodes:
|
||||
segment_ids.append(i.segment_id)
|
||||
child_chunk_map[i.segment_id] = i
|
||||
doc_segment_map[i.segment_id] = i.index_node_id
|
||||
if i.segment_id in child_chunk_map:
|
||||
child_chunk_map[i.segment_id].append(i)
|
||||
else:
|
||||
child_chunk_map[i.segment_id] = [i]
|
||||
if i.segment_id in doc_segment_map:
|
||||
doc_segment_map[i.segment_id].append(i.index_node_id)
|
||||
else:
|
||||
doc_segment_map[i.segment_id] = [i.index_node_id]
|
||||
|
||||
if index_node_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
|
|
@ -448,7 +458,7 @@ class RetrievalService:
|
|||
)
|
||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
for index_node_segment in index_node_segments:
|
||||
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
||||
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
|
||||
if segment_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
|
|
@ -461,95 +471,86 @@ class RetrievalService:
|
|||
segments.extend(index_node_segments)
|
||||
|
||||
for segment in segments:
|
||||
doc_id = doc_segment_map.get(segment.id)
|
||||
child_chunk = child_chunk_map.get(segment.id)
|
||||
attachment_info = attachment_map.get(segment.id)
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
|
||||
|
||||
if doc_id:
|
||||
document = doc_to_document_map[doc_id]
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
||||
document.metadata.get("document_id")
|
||||
)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunk:
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunks or attachment_infos:
|
||||
child_chunk_details = []
|
||||
max_score = 0.0
|
||||
for child_chunk in child_chunks:
|
||||
document = doc_to_document_map[child_chunk.index_node_id]
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record = {
|
||||
"segment": segment,
|
||||
child_chunk_details.append(child_chunk_detail)
|
||||
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
|
||||
for attachment_info in attachment_infos:
|
||||
file_document = doc_to_document_map[attachment_info["id"]]
|
||||
max_score = max(
|
||||
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||
)
|
||||
|
||||
map_detail = {
|
||||
"max_score": max_score,
|
||||
"child_chunks": child_chunk_details,
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if child_chunk:
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"],
|
||||
document.metadata.get("score", 0.0) if document else 0.0,
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
if segment.id in segment_file_map:
|
||||
segment_file_map[segment.id].append(attachment_info)
|
||||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score", 0.0), # type: ignore
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
records.append(record)
|
||||
else:
|
||||
if attachment_info:
|
||||
attachment_infos = segment_file_map.get(segment.id, [])
|
||||
if attachment_info not in attachment_infos:
|
||||
attachment_infos.append(attachment_info)
|
||||
segment_file_map[segment.id] = attachment_infos
|
||||
segment_child_map[segment.id] = map_detail
|
||||
record: dict[str, Any] = {
|
||||
"segment": segment,
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
max_score = 0.0
|
||||
segment_document = doc_to_document_map.get(segment.index_node_id)
|
||||
if segment_document:
|
||||
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
|
||||
for attachment_info in attachment_infos:
|
||||
file_doc = doc_to_document_map.get(attachment_info["id"])
|
||||
if file_doc:
|
||||
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": max_score,
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
# Add child chunks information to records
|
||||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
if record["segment"].id in segment_file_map:
|
||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||
if record["segment"].id in attachment_map:
|
||||
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
result = []
|
||||
result: list[RetrievalSegments] = []
|
||||
for record in records:
|
||||
# Extract segment
|
||||
segment = record["segment"]
|
||||
|
||||
# Extract child_chunks, ensuring it's a list or None
|
||||
child_chunks = record.get("child_chunks")
|
||||
if not isinstance(child_chunks, list):
|
||||
child_chunks = None
|
||||
raw_child_chunks = record.get("child_chunks")
|
||||
child_chunks_list: list[RetrievalChildChunk] | None = None
|
||||
if isinstance(raw_child_chunks, list):
|
||||
# Sort by score descending
|
||||
sorted_chunks = sorted(raw_child_chunks, key=lambda x: x.get("score", 0.0), reverse=True)
|
||||
child_chunks_list = [
|
||||
RetrievalChildChunk(
|
||||
id=chunk["id"],
|
||||
content=chunk["content"],
|
||||
score=chunk.get("score", 0.0),
|
||||
position=chunk["position"],
|
||||
)
|
||||
for chunk in sorted_chunks
|
||||
]
|
||||
|
||||
# Extract files, ensuring it's a list or None
|
||||
files = record.get("files")
|
||||
|
|
@ -566,11 +567,11 @@ class RetrievalService:
|
|||
|
||||
# Create RetrievalSegments object
|
||||
retrieval_segment = RetrievalSegments(
|
||||
segment=segment, child_chunks=child_chunks, score=score, files=files
|
||||
segment=segment, child_chunks=child_chunks_list, score=score, files=files
|
||||
)
|
||||
result.append(retrieval_segment)
|
||||
|
||||
return result
|
||||
return sorted(result, key=lambda x: x.score if x.score is not None else 0.0, reverse=True)
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
raise e
|
||||
|
|
|
|||
Loading…
Reference in New Issue