diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 12053377e2..8760d60de0 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -4,7 +4,7 @@ import logging import time import uuid from datetime import UTC, datetime -from typing import Any +from typing import TypedDict, cast from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelType @@ -25,6 +25,22 @@ from models.enums import SummaryStatus logger = logging.getLogger(__name__) +class SummaryEntryDict(TypedDict): + segment_id: str + segment_position: int + status: str + summary_preview: str | None + error: str | None + created_at: int | None + updated_at: int | None + + +class DocumentSummaryStatusDetailDict(TypedDict): + total_segments: int + summary_status: dict[str, int] + summaries: list[SummaryEntryDict] + + class SummaryIndexService: """Service for generating and managing summary indexes.""" @@ -1352,7 +1368,7 @@ class SummaryIndexService: def get_document_summary_status_detail( document_id: str, dataset_id: str, - ) -> dict[str, Any]: + ) -> DocumentSummaryStatusDetailDict: """ Get detailed summary status for a document. @@ -1403,7 +1419,7 @@ class SummaryIndexService: SummaryStatus.NOT_STARTED: 0, } - summary_list = [] + summary_list: list[SummaryEntryDict] = [] for segment in segments: summary = summary_map.get(segment.id) if summary: @@ -1438,8 +1454,8 @@ class SummaryIndexService: } ) - return { - "total_segments": total_segments, - "summary_status": status_counts, - "summaries": summary_list, - } + return DocumentSummaryStatusDetailDict( + total_segments=total_segments, + summary_status=cast(dict[str, int], status_counts), + summaries=summary_list, + )