refactor: migrate legacy session.query calls to SQLAlchemy 2.0-style select/update in dataset index tasks (#34697)

This commit is contained in:
Renzo 2026-04-07 19:57:11 -05:00 committed by GitHub
parent ae9fcc2969
commit 2127d5850f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 22 additions and 26 deletions

View File

@@ -73,7 +73,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -27,7 +28,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
start_at = time.perf_counter()
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
@@ -39,11 +40,10 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
try:
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: SegmentStatus.INDEXING,
DocumentSegment.indexing_at: naive_utc_now(),
}
session.execute(
update(DocumentSegment)
.where(DocumentSegment.id == segment.id)
.values(status=SegmentStatus.INDEXING, indexing_at=naive_utc_now())
)
session.commit()
document = Document(
@@ -81,11 +81,10 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
index_processor.load(dataset, [document])
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.completed_at: naive_utc_now(),
}
session.execute(
update(DocumentSegment)
.where(DocumentSegment.id == segment.id)
.values(status=SegmentStatus.COMPLETED, completed_at=naive_utc_now())
)
session.commit()

View File

@@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@@ -30,12 +30,12 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
"""
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1))
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
@@ -123,17 +123,14 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.values(error=str(e), status="error", disabled_at=naive_utc_now(), enabled=False)
)
session.commit()
finally: