refactor(api): type bare dict/list annotations in remaining rag folder (#33775)

This commit is contained in:
BitToby 2026-03-19 20:31:06 +02:00 committed by GitHub
parent 5b9cb55c45
commit f40f6547b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 42 additions and 33 deletions

View File

@ -1,9 +1,10 @@
import re
from typing import Any
class CleanProcessor:
@classmethod
def clean(cls, text: str, process_rule: dict) -> str:
def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str:
# default clean
# remove invalid symbol
text = re.sub(r"<\|", "<", text)

View File

@ -4,6 +4,7 @@ from typing import Any
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -15,6 +16,11 @@ from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class PreSegmentData(TypedDict):
segment: DocumentSegment
keywords: list[str]
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
@ -128,7 +134,7 @@ class Jieba(BaseKeyword):
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
storage.delete(file_key)
def _save_dataset_keyword_table(self, keyword_table):
def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None):
keyword_table_dict = {
"__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
@ -144,7 +150,7 @@ class Jieba(BaseKeyword):
storage.delete(file_key)
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))
def _get_dataset_keyword_table(self) -> dict | None:
def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
@ -169,14 +175,16 @@ class Jieba(BaseKeyword):
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]):
def _add_text_to_keyword_table(
self, keyword_table: dict[str, set[str]], id: str, keywords: list[str]
) -> dict[str, set[str]]:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]):
def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
@ -193,7 +201,7 @@ class Jieba(BaseKeyword):
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]:
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
@ -228,7 +236,7 @@ class Jieba(BaseKeyword):
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:

View File

@ -103,7 +103,7 @@ class RetrievalService:
reranking_mode: str = "reranking_model",
weights: WeightsDict | None = None,
document_ids_filter: list[str] | None = None,
attachment_ids: list | None = None,
attachment_ids: list[str] | None = None,
):
if not query and not attachment_ids:
return []
@ -250,8 +250,8 @@ class RetrievalService:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
exceptions: list,
all_documents: list[Document],
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():
@ -279,9 +279,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: RetrievalMethod,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
query_type: QueryType = QueryType.TEXT_QUERY,
):
@ -373,9 +373,9 @@ class RetrievalService:
top_k: int,
score_threshold: float | None,
reranking_model: RerankingModelDict | None,
all_documents: list,
all_documents: list[Document],
retrieval_method: str,
exceptions: list,
exceptions: list[str],
document_ids_filter: list[str] | None = None,
):
with flask_app.app_context():

View File

@ -366,7 +366,7 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
hyperlink_field_text_parts: list[str] = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:

View File

@ -591,7 +591,7 @@ class DatasetRetrieval:
user_id: str,
user_from: str,
query: str,
available_datasets: list,
available_datasets: list[Dataset],
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
@ -633,15 +633,15 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
selected_dataset = db.session.scalar(dataset_stmt)
if selected_dataset:
results = []
if dataset.provider == "external":
if selected_dataset.provider == "external":
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
tenant_id=selected_dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
external_retrieval_parameters=selected_dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
@ -654,28 +654,28 @@ class DatasetRetrieval:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
document.metadata["dataset_name"] = selected_dataset.name
results.append(document)
else:
if metadata_condition and not metadata_filter_document_ids:
return []
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
document_ids = metadata_filter_document_ids.get(selected_dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config: DefaultRetrievalModelDict = (
cast(DefaultRetrievalModelDict, dataset.retrieval_model)
if dataset.retrieval_model
cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model)
if selected_dataset.retrieval_model
else default_retrieval_model
)
# get top k
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
if selected_dataset.indexing_technique == "economy":
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
@ -694,7 +694,7 @@ class DatasetRetrieval:
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
dataset_id=selected_dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
@ -726,7 +726,7 @@ class DatasetRetrieval:
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
available_datasets: list[Dataset],
query: str | None,
top_k: int,
score_threshold: float,
@ -1028,7 +1028,7 @@ class DatasetRetrieval:
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
all_documents: list[Document],
document_ids_filter: list[str] | None = None,
metadata_condition: MetadataCondition | None = None,
attachment_ids: list[str] | None = None,
@ -1298,7 +1298,7 @@ class DatasetRetrieval:
def get_metadata_filter_condition(
self,
dataset_ids: list,
dataset_ids: list[str],
query: str,
tenant_id: str,
user_id: str,
@ -1400,7 +1400,7 @@ class DatasetRetrieval:
return output
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> list[dict[str, Any]] | None:
# get all metadata field
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
@ -1598,7 +1598,7 @@ class DatasetRetrieval:
)
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str
):
model_mode = ModelMode(mode)
input_text = query
@ -1690,7 +1690,7 @@ class DatasetRetrieval:
def _multiple_retrieve_thread(
self,
flask_app: Flask,
available_datasets: list,
available_datasets: list[Dataset],
metadata_condition: MetadataCondition | None,
metadata_filter_document_ids: dict[str, list[str]] | None,
all_documents: list[Document],