mirror of https://github.com/langgenius/dify.git
refactor: streamline database session usage in batch_create_segment_to_index_task (#26795)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
7b8540281a
commit
a3b33cbe28
|
|
@ -8,7 +8,6 @@ import click
|
|||
import pandas as pd
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -50,54 +49,48 @@ def batch_create_segment_to_index_task(
|
|||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
dataset_document = db.session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
raise ValueError("Document is not available.")
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
raise ValueError("Document is not available.")
|
||||
|
||||
upload_file = session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
upload_file = db.session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
# Skip the first row
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
|
|
@ -105,6 +98,7 @@ def batch_create_segment_to_index_task(
|
|||
)
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
|
@ -135,11 +129,11 @@ def batch_create_segment_to_index_task(
|
|||
word_count_change += segment_document.word_count
|
||||
db.session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
# update document word count
|
||||
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
db.session.add(dataset_document)
|
||||
# add index to db
|
||||
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
|
|
|
|||
Loading…
Reference in New Issue