refactor(api): type default_retrieval_model with DefaultRetrievalModelDict in core/rag (#33676)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
BitToby 2026-03-18 15:47:51 +02:00 committed by GitHub
parent 29c70736dc
commit 25ab5e46b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 7 deletions

View File

@ -68,9 +68,12 @@ class SegmentRecord(TypedDict):
class DefaultRetrievalModelDict(TypedDict):
search_method: RetrievalMethod | str
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModelDict
reranking_mode: NotRequired[str]
weights: NotRequired[WeightsDict | None]
score_threshold: NotRequired[float]
top_k: int
score_threshold_enabled: bool

View File

@ -33,7 +33,7 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
@ -87,7 +87,7 @@ from models.enums import CreatorUserRole, DatasetQuerySource
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -666,7 +666,11 @@ class DatasetRetrieval:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
@ -1058,7 +1062,11 @@ class DatasetRetrieval:
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_model: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
if dataset.indexing_technique == "economy":
# use keyword table query
@ -1132,7 +1140,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
default_retrieval_model: DefaultRetrievalModelDict = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -1141,7 +1149,11 @@ class DatasetRetrieval:
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]