refactor(api): replace dict/Mapping with TypedDict in core/rag retrieval_service.py (#33615)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
BitToby 2026-03-18 04:49:09 +02:00 committed by GitHub
parent d7f70f3c0f
commit 485da15a4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 165 additions and 71 deletions

View File

@ -8,6 +8,7 @@ from core.app.app_config.entities import (
ModelConfig,
)
from core.entities.agent_entities import PlanningStrategy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from models.model import AppMode, AppModelConfigDict
from services.dataset_service import DatasetService
@ -117,8 +118,10 @@ class DatasetConfigManager:
score_threshold=float(score_threshold_val)
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
else None,
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
weights=weights_val if isinstance(weights_val, dict) else None,
reranking_model=cast(RerankingModelDict, reranking_model_val)
if isinstance(reranking_model_val, dict)
else None,
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
metadata_filtering_mode=cast(

View File

@ -4,6 +4,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.file import FileUploadConfig
from dify_graph.model_runtime.entities.llm_entities import LLMMode
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel):
top_k: int | None = None
score_threshold: float | None = 0.0
rerank_mode: str | None = "reranking_model"
reranking_model: dict | None = None
weights: dict | None = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
reranking_enabled: bool | None = True
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
metadata_model_config: ModelConfig | None = None

View File

@ -1,3 +1,5 @@
from typing_extensions import TypedDict
from core.model_manager import ModelInstance, ModelManager
from core.rag.data_post_processor.reorder import ReorderRunner
from core.rag.index_processor.constant.query_type import QueryType
@ -10,6 +12,26 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
class RerankingModelDict(TypedDict):
    """Typed shape of a reranking model configuration.

    Replaces the untyped ``dict`` previously used for ``reranking_model``
    arguments. Both keys are required by the type (neither is marked
    ``NotRequired``), so code may index them directly instead of using
    ``.get()``.
    """

    # Provider name; passed as ``provider`` when looking up the rerank model.
    reranking_provider_name: str
    # Model name; passed as ``model`` when looking up the rerank model.
    reranking_model_name: str
class VectorSettingDict(TypedDict):
    """Typed shape of the ``vector_setting`` entry of ``WeightsDict``."""

    # Weight of the vector component.
    # NOTE(review): presumably relative to ``keyword_weight`` — confirm.
    vector_weight: float
    # Name of the embedding provider.
    embedding_provider_name: str
    # Name of the embedding model.
    embedding_model_name: str
class KeywordSettingDict(TypedDict):
    """Typed shape of the ``keyword_setting`` entry of ``WeightsDict``."""

    # Weight of the keyword component.
    # NOTE(review): presumably relative to ``vector_weight`` — confirm.
    keyword_weight: float
class WeightsDict(TypedDict):
    """Typed shape of the ``weights`` configuration for weighted-score reranking.

    Replaces the untyped ``dict`` previously used for ``weights`` arguments.
    Both nested settings are required by the type.
    """

    # Settings for the vector component.
    vector_setting: VectorSettingDict
    # Settings for the keyword component.
    keyword_setting: KeywordSettingDict
class DataPostProcessor:
"""Interface for data post-processing document."""
@ -17,8 +39,8 @@ class DataPostProcessor:
self,
tenant_id: str,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reorder_enabled: bool = False,
):
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
@ -45,8 +67,8 @@ class DataPostProcessor:
self,
reranking_mode: str,
tenant_id: str,
reranking_model: dict | None = None,
weights: dict | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
@ -79,12 +101,14 @@ class DataPostProcessor:
return ReorderRunner()
return None
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
def _get_rerank_model_instance(
self, tenant_id: str, reranking_model: RerankingModelDict | None
) -> ModelInstance | None:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
reranking_provider_name = reranking_model["reranking_provider_name"]
reranking_model_name = reranking_model["reranking_model_name"]
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(

View File

@ -1,19 +1,20 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Any, NotRequired
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from typing_extensions import TypedDict
from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -35,7 +36,46 @@ from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
class SegmentAttachmentResult(TypedDict):
    """Return shape of ``RetrievalService.get_segment_attachment_info``."""

    # Description of the uploaded file behind the attachment.
    attachment_info: AttachmentInfoDict
    # Segment id for the attachment.
    # NOTE(review): presumably taken from the attachment binding — confirm.
    segment_id: str
class SegmentAttachmentInfoResult(TypedDict):
    """Element shape of the list returned by ``RetrievalService.get_segment_attachment_infos``."""

    # ``attachment_id`` of the attachment binding.
    attachment_id: str
    # Description of the uploaded file behind the attachment.
    attachment_info: AttachmentInfoDict
    # ``segment_id`` of the attachment binding.
    segment_id: str
class ChildChunkDetail(TypedDict):
    """Typed shape of one child chunk entry attached to a retrieved segment."""

    # Child chunk id.
    id: str
    # Child chunk text content.
    content: str
    # Child chunk position.
    position: int
    # Score for this chunk.
    # NOTE(review): presumably the matched document's metadata "score",
    # 0.0 when there is no match — confirm.
    score: float
class SegmentChildMapDetail(TypedDict):
    """Value shape of ``segment_child_map``: child chunk data for one segment."""

    # Highest score seen for the segment; copied into the record's ``score``.
    max_score: float
    # Child chunk details for the segment; may be an empty list.
    child_chunks: list[ChildChunkDetail]
class SegmentRecord(TypedDict):
    """Intermediate record built for each retrieved segment.

    Only ``segment`` is required. The remaining keys are ``NotRequired``
    because records are created with a subset of keys and filled in later.
    """

    # The retrieved document segment.
    segment: DocumentSegment
    # Relevance score; absent on records created with only ``segment``.
    score: NotRequired[float]
    # Child chunk details; set only when the segment is in ``segment_child_map``.
    child_chunks: NotRequired[list[ChildChunkDetail]]
    # Attachment descriptions; set only when the segment is in ``attachment_map``.
    files: NotRequired[list[AttachmentInfoDict]]
class DefaultRetrievalModelDict(TypedDict):
    """Typed shape of the module-level ``default_retrieval_model`` mapping."""

    # Retrieval method, as a ``RetrievalMethod`` member or a plain string.
    search_method: RetrievalMethod | str
    # Whether reranking is enabled.
    reranking_enable: bool
    # Reranking model configuration.
    reranking_model: RerankingModelDict
    # Number of results to retrieve.
    top_k: int
    # Whether a score threshold is enabled.
    score_threshold_enabled: bool
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -56,9 +96,9 @@ class RetrievalService:
query: str,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
):
@ -235,7 +275,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: RetrievalMethod,
exceptions: list,
@ -277,8 +317,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
@ -288,8 +328,8 @@ class RetrievalService:
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=dataset.tenant_id,
provider=reranking_model.get("reranking_provider_name") or "",
model=reranking_model.get("reranking_model_name") or "",
provider=reranking_model["reranking_provider_name"],
model=reranking_model["reranking_model_name"],
model_type=ModelType.RERANK,
)
if is_support_vision:
@ -329,7 +369,7 @@ class RetrievalService:
query: str,
top_k: int,
score_threshold: float | None,
reranking_model: dict | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
retrieval_method: str,
exceptions: list,
@ -349,8 +389,8 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and reranking_model["reranking_model_name"]
and reranking_model["reranking_provider_name"]
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
@ -459,7 +499,7 @@ class RetrievalService:
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[dict[str, Any]]] = {}
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@ -544,12 +584,12 @@ class RetrievalService:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
segment_child_map: dict[str, SegmentChildMapDetail] = {}
records: list[SegmentRecord] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@ -560,14 +600,14 @@ class RetrievalService:
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
child_chunk_details: list[ChildChunkDetail] = []
for child_chunk in child_chunks:
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
if child_document:
child_score = child_document.metadata.get("score", 0.0)
else:
child_score = 0.0
child_chunk_detail = {
child_chunk_detail: ChildChunkDetail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
@ -580,7 +620,7 @@ class RetrievalService:
if file_document:
max_score = max(max_score, file_document.metadata.get("score", 0.0))
map_detail = {
map_detail: SegmentChildMapDetail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
@ -593,7 +633,7 @@ class RetrievalService:
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
record: SegmentRecord = {
"segment": segment,
}
records.append(record)
@ -617,19 +657,19 @@ class RetrievalService:
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
another_record: SegmentRecord = {
"segment": segment,
"score": max_score,
}
records.append(record)
records.append(another_record)
# Add child chunks information to records
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
record["score"] = segment_child_map[record["segment"].id]["max_score"]
if record["segment"].id in attachment_map:
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
record["files"] = attachment_map[record["segment"].id]
result: list[RetrievalSegments] = []
for record in records:
@ -693,9 +733,9 @@ class RetrievalService:
query: str | None = None,
top_k: int = 4,
score_threshold: float | None = 0.0,
reranking_model: dict | None = None,
reranking_model: RerankingModelDict | None = None,
reranking_mode: str = "reranking_model",
weights: dict | None = None,
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_id: str | None = None,
):
@ -807,7 +847,7 @@ class RetrievalService:
@classmethod
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> dict[str, Any] | None:
) -> SegmentAttachmentResult | None:
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = (
@ -816,7 +856,7 @@ class RetrievalService:
.first()
)
if attachment_binding:
attachment_info = {
attachment_info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -828,8 +868,10 @@ class RetrievalService:
return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
attachment_infos = []
def get_segment_attachment_infos(
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
@ -843,7 +885,7 @@ class RetrievalService:
if attachment_bindings:
for upload_file in upload_files:
attachment_binding = attachment_binding_map.get(upload_file.id)
attachment_info = {
info: AttachmentInfoDict = {
"id": upload_file.id,
"name": upload_file.name,
"extension": "." + upload_file.extension,
@ -855,7 +897,7 @@ class RetrievalService:
attachment_infos.append(
{
"attachment_id": attachment_binding.attachment_id,
"attachment_info": attachment_info,
"attachment_info": info,
"segment_id": attachment_binding.segment_id,
}
)

View File

@ -1,8 +1,18 @@
from pydantic import BaseModel
from typing_extensions import TypedDict
from models.dataset import DocumentSegment
class AttachmentInfoDict(TypedDict):
    """Typed description of an uploaded file attached to a segment.

    Element type of ``RetrievalSegments.files``.
    """

    # Uploaded file id.
    id: str
    # Uploaded file name.
    name: str
    # File extension; RetrievalService builds it as "." + the stored extension.
    extension: str
    # MIME type of the file.
    mime_type: str
    # URL of the file.
    # NOTE(review): presumably a signed/preview URL — confirm.
    source_url: str
    # File size.
    # NOTE(review): presumably in bytes — confirm.
    size: int
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel):
segment: DocumentSegment
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
files: list[AttachmentInfoDict] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -15,6 +15,7 @@ import httpx
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, Document
@ -98,7 +99,7 @@ class BaseIndexProcessor(ABC):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
raise NotImplementedError

View File

@ -14,6 +14,7 @@ from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(

View File

@ -11,6 +11,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(

View File

@ -15,6 +15,7 @@ from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.data_post_processor.data_post_processor import RerankingModelDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor):
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
reranking_model: RerankingModelDict,
):
# Set search parameters.
results = RetrievalService.retrieve(

View File

@ -31,7 +31,7 @@ from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
@ -727,8 +727,8 @@ class DatasetRetrieval:
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: dict | None = None,
weights: dict[str, Any] | None = None,
reranking_model: RerankingModelDict | None = None,
weights: WeightsDict | None = None,
reranking_enable: bool = True,
message_id: str | None = None,
metadata_filter_document_ids: dict[str, list[str]] | None = None,
@ -1181,8 +1181,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"],
reranking_model_name=retrieve_config.reranking_model["reranking_model_name"],
)
tools.append(tool)
@ -1685,8 +1685,8 @@ class DatasetRetrieval:
tenant_id: str,
reranking_enable: bool,
reranking_mode: str,
reranking_model: dict | None,
weights: dict[str, Any] | None,
reranking_model: RerankingModelDict | None,
weights: WeightsDict | None,
top_k: int,
score_threshold: float,
query: str | None,

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
@ -20,9 +21,9 @@ from services.external_knowledge_service import ExternalDatasetService
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: dict[str, str]
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[dict[str, object] | None]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool

View File

@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict
@ -201,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
reranking_model: RerankingModelDict | None = None
weights: WeightsDict | None = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:

View File

@ -2,6 +2,7 @@ from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from dify_graph.model_runtime.entities import LLMUsage
from dify_graph.nodes.llm.entities import ModelConfig
@ -75,8 +76,8 @@ class KnowledgeRetrievalRequest(BaseModel):
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration")
weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")

View File

@ -510,7 +510,7 @@ class TestWorkflowConverter:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=10,
score_threshold=0.8,
reranking_model={"provider": "cohere", "model": "rerank-v2"},
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
reranking_enabled=True,
),
)
@ -543,8 +543,8 @@ class TestWorkflowConverter:
multiple_config = node["data"]["multiple_retrieval_config"]
assert multiple_config["top_k"] == 10
assert multiple_config["score_threshold"] == 0.8
assert multiple_config["reranking_model"]["provider"] == "cohere"
assert multiple_config["reranking_model"]["model"] == "rerank-v2"
assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere"
assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2"
# Verify single retrieval config is None for multiple strategy
assert node["data"]["single_retrieval_config"] is None

View File

@ -236,7 +236,8 @@ class TestParagraphIndexProcessor:
"core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [accepted, rejected]
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].metadata["score"] == 0.9

View File

@ -307,7 +307,8 @@ class TestParentChildIndexProcessor:
"core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve"
) as mock_retrieve:
mock_retrieve.return_value = [ok_result, low_result]
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "keep"

View File

@ -262,7 +262,8 @@ class TestQAIndexProcessor:
with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve:
mock_retrieve.return_value = [result_ok, result_low]
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, {})
reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""}
docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model)
assert len(docs) == 1
assert docs[0].page_content == "accepted"

View File

@ -25,6 +25,7 @@ from core.app.app_config.entities import ModelConfig as WorkflowModelConfig
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.rag.data_post_processor.data_post_processor import WeightsDict
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -4686,7 +4687,10 @@ class TestSingleAndMultipleRetrieveCoverage:
extra={"dataset_name": "Ext", "title": "Ext"},
)
app = Flask(__name__)
weights = {"vector_setting": {}}
weights: WeightsDict = {
"vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""},
"keyword_setting": {"keyword_weight": 0.5},
}
def fake_multiple_thread(**kwargs):
if kwargs["query"]: