diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 7f6ecc3d3f..d7ea03efee 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -68,9 +68,12 @@ class SegmentRecord(TypedDict): class DefaultRetrievalModelDict(TypedDict): - search_method: RetrievalMethod | str + search_method: RetrievalMethod reranking_enable: bool reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] top_k: int score_threshold_enabled: bool diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c44e9b847b..1096c69041 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -33,7 +33,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -666,7 +666,11 @@ class DatasetRetrieval: document_ids_filter = document_ids else: return [] - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"] @@ -1058,7 +1062,11 @@ class DatasetRetrieval: all_documents.append(document) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) if dataset.indexing_technique == "economy": # use keyword table query @@ -1132,7 +1140,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config - default_retrieval_model = { + default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -1141,7 +1149,11 @@ class DatasetRetrieval: } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"]