diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 4a731bf277..a487c49053 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -3,8 +3,7 @@ import logging import re import uuid -from collections.abc import Mapping -from typing import Any, cast +from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) @@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService _file_access_controller = DatabaseFileAccessController() +class ParagraphFormatPreviewDict(TypedDict): + chunk_structure: str + preview: list[dict[str, Any]] + total_segments: int + + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return { + result: ParagraphFormatPreviewDict = { "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks), } + return result else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 53596b5de8..2db233874a 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,8 +3,7 @@ import json import logging import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from sqlalchemy import delete, select @@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class ParentChildFormatPreviewDict(TypedDict): + chunk_structure: str + parent_mode: str + preview: list[dict[str, Any]] + total_segments: int + + class ParentChildIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -351,17 +357,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict: parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) - return { + result: ParentChildFormatPreviewDict = { "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 273ea0f852..b0f7928092 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,8 +4,7 @@ import logging import re import threading import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app @@ -36,6 +35,12 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class QAFormatPreviewDict(TypedDict): + chunk_structure: str + qa_preview: list[dict[str, Any]] + total_segments: int + + class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -230,16 +235,17 @@ class QAIndexProcessor(BaseIndexProcessor): else: raise ValueError("Indexing technique must be high quality.") - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> QAFormatPreviewDict: qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) - return { + result: QAFormatPreviewDict = { "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } + return result def generate_summary_preview( self,