refactor: streamline database session usage in batch_create_segment_to_index_task (#26795)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Guangdong Liu 2025-10-14 09:22:48 +08:00 committed by GitHub
parent 7b8540281a
commit a3b33cbe28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 39 additions and 45 deletions

View File

@ -8,7 +8,6 @@ import click
import pandas as pd
from celery import shared_task
from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
with Session(db.engine) as session:
dataset = session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset_document = session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
dataset_document = db.session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
raise ValueError("Document is not available.")
upload_file = session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
# Skip the first row
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
# update document word count
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")