mirror of https://github.com/langgenius/dify.git
refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b035b091fa
commit
55de893984
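The refactor replaces one long-lived per-task database session with several short, discrete ones. Below is a minimal, self-contained sketch of the shape the hunks in this commit converge on, assuming session_factory.create_session() yields a context-managed SQLAlchemy Session (as the diff uses it); the stand-in factory here is illustrative, not dify's implementation:

    from contextlib import contextmanager
    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("sqlite://")          # stand-in engine; dify wires this to its own config
    SessionLocal = sessionmaker(bind=engine)

    @contextmanager
    def create_session():
        # Rough stand-in for session_factory.create_session() as used in the hunks below.
        session = SessionLocal()
        try:
            yield session
        finally:
            session.close()                      # closed even if the body raises

    def task_new_shape() -> None:
        # Phase 1: short read session; copy what you need into plain Python data.
        with create_session() as session:
            value = session.execute(text("SELECT 1")).scalar_one()
        # Phase 2: slow work holds no database connection.
        result = value + 1
        # Phase 3: short write transaction; session.begin() commits on clean exit.
        with create_session() as session, session.begin():
            session.execute(text("SELECT :v"), {"v": result})

    task_new_shape()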
@@ -6,7 +6,6 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

@@ -5,7 +5,6 @@ import click

from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

@@ -6,7 +6,6 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

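In all three annotation tasks the trailing finally: db.session.close() disappears; once the body runs inside a context-managed session, cleanup is automatic. A small sketch of the equivalence (build_index and the injected callables are placeholders, not dify APIs):

    import logging

    logger = logging.getLogger(__name__)

    def build_index(session=None) -> None:
        pass  # stand-in for the real annotation index build

    # Old shape: manual cleanup in a finally block.
    def old_task(close_session) -> None:
        try:
            build_index()
        except Exception:
            logger.exception("Build index for annotation failed")
        finally:
            close_session()                   # the removed db.session.close()

    # New shape: the context manager closes the session even when the body raises.
    def new_task(create_session) -> None:
        with create_session() as session:     # __exit__ performs the close
            try:
                build_index(session)
            except Exception:
                logger.exception("Build index for annotation failed")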
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(

indexing_cache_key = f"segment_batch_import_{job_id}"

# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None

with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)

@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
storage.download(upload_file.key, file_path)
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
upload_file_key = upload_file.key
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return

# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return

with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
storage.download(upload_file_key, file_path)

df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
tokens_list = [0] * len(content)
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")

for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)

document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)

word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)

with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)

with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)

VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])

redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)

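The rewritten batch task splits one try/except over a single session into discrete steps: a first session snapshots the upload key and dataset/document config into plain dicts, token counting runs with no session open, a second session writes all segments in one transaction, and a third applies the word-count update. A condensed sketch of that flow (session_factory, Dataset, Document, and redis_client are the names from the diff; parse_rows, count_tokens, and build_segment are placeholders, not dify APIs):

    def batch_sketch(job_id, dataset_id, document_id):
        cache_key = f"segment_batch_import_{job_id}"

        # Session 1: read ORM objects, keep only plain data past the session boundary.
        with session_factory.create_session() as session:
            dataset = session.get(Dataset, dataset_id)
            dataset_config = {"id": dataset.id, "tenant_id": dataset.tenant_id,
                              "indexing_technique": dataset.indexing_technique}

        rows = parse_rows()                # placeholder: CSV parsing, no session open
        tokens_list = count_tokens(rows)   # placeholder: embedding/token work, no session open

        # Session 2: all segment inserts in one short transaction.
        with session_factory.create_session() as session, session.begin():
            for row, tokens in zip(rows, tokens_list):
                session.add(build_segment(row, tokens))   # placeholder constructor

        # Session 3: a separate short transaction for the parent document update.
        with session_factory.create_session() as session, session.begin():
            document = session.get(Document, document_id)
            document.word_count += sum(len(r["content"]) for r in rows)

        redis_client.setex(cache_key, 600, "completed")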
@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []

with session_factory.create_session() as session:
try:

@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()

attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])

index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return

# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)

for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)

image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)

session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)

binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)

# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception("Cleaned document when document deleted failed")
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
)

with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)

with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)

if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)

for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)

with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)

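The cleanup task applies the same split to storage I/O: file keys are collected into total_image_files and total_attachment_files inside short transactions, rows are deleted in bulk with delete(...).where(...), and the slow storage.delete() calls run between sessions so no transaction is held open during network calls. A reduced sketch of that shape (generic SQLAlchemy; UploadFile, storage, and session_factory are the names used above, upload_file_ids stands in for the ids gathered earlier):

    from sqlalchemy import delete, select

    def clean_sketch(upload_file_ids: list[str]) -> None:
        # Short transaction: snapshot the storage keys, then delete the rows in bulk.
        with session_factory.create_session() as session, session.begin():
            files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
            keys = [f.key for f in files]
            session.execute(delete(UploadFile).where(UploadFile.id.in_(upload_file_ids)))

        # Storage deletes run with no session open; each failure is logged, not fatal.
        for key in keys:
            try:
                storage.delete(key)
            except Exception:
                logger.exception("Delete failed when storage deleted, key: %s", key)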
@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return

for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))

document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)

for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
# Transaction committed and closed

try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True

if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)

@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)

documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)

for document in documents:
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,

@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
document.id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
document.id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
logger.warning("Document %s not found after indexing", document.id)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)


def _document_indexing_with_tenant_queue(

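The indexing task now has explicit phases: Phase 1 flips the whole batch to "parsing" in one short transaction (a single .in_() query replaces the per-document loop), Phase 2 runs IndexingRunner with no task-level transaction since it manages its own sessions, and the summary check re-reads documents in a fresh session only when no error occurred. The control flow, condensed (names from the hunks above; the Phase 3 body is elided in the trailing comment):

    def indexing_sketch(dataset_id, document_ids):
        # Phase 1: short transaction to mark every document as parsing.
        with session_factory.create_session() as session, session.begin():
            documents = (
                session.query(Document)
                .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
                .all()
            )
            for document in documents:
                document.indexing_status = "parsing"
                document.processing_started_at = naive_utc_now()

        # Phase 2: no surrounding transaction; failures set a flag rather than skip cleanup.
        has_error = False
        try:
            IndexingRunner().run(documents)
        except Exception:
            logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
            has_error = True

        # Phase 3: a fresh session sees the statuses IndexingRunner committed.
        if not has_error:
            with session_factory.create_session() as session:
                session.expire_all()
                # ... re-query documents and queue generate_summary_index_task per the hunk above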
@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy.orm import Session

from extensions.ext_database import db
from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService

@@ -17,6 +16,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

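For the draft-variable deletion task the only change is where the session comes from; behavior is unchanged because Session.begin() opens a transaction that commits on clean exit and rolls back on exception. A self-contained demonstration of that contract with plain SQLAlchemy:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("sqlite://")
    SessionLocal = sessionmaker(bind=engine)

    with SessionLocal() as session, session.begin():
        session.execute(text("CREATE TABLE t (x INTEGER)"))
        session.execute(text("INSERT INTO t VALUES (1)"))
        # no explicit commit: session.begin() commits here on clean exit

    with SessionLocal() as session:
        assert session.execute(text("SELECT count(*) FROM t")).scalar_one() == 1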
@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download

# Execute the task
# Execute the task - should raise ValueError for empty CSV
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)

# Verify error handling
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client

cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"

# Verify no segments were created
# Since exception was raised, no segments should be created
from extensions.ext_database import db

segments = db.session.query(DocumentSegment).all()

@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
sessions = []  # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}

def _exit_side_effect(*args, **kwargs):
session.close()
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()

cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session

query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
def _exit_side_effect(*args, **kwargs):
session.close()

cm.__exit__.side_effect = _exit_side_effect

# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session

def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()

begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)

sessions.append(session)

# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)

# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)

where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query

session.query = MagicMock(side_effect=query_side_effect)
return cm

mock_sf.create_session.side_effect = create_session_side_effect

# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()

# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session

def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()

begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)

def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)

if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)

where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query

self._default_session.query = MagicMock(side_effect=default_query_side_effect)

def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)

@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions

wrapper = SessionWrapper()
yield wrapper


@pytest.fixture

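The fixture now yields a SessionWrapper instead of a single mock session: tests seed data once through _shared_data, the wrapper hands every session_factory.create_session() call a fresh mock wired to that data, and all_sessions lets a test assert over each session the task created. A usage sketch matching the test hunks below (fixture and function names from this test module):

    def test_sessions_are_closed(mock_db_session, mock_dataset, mock_documents):
        # Seed once; every mocked session's query() serves this data.
        mock_db_session._shared_data["dataset"] = mock_dataset
        mock_db_session._shared_data["documents"] = mock_documents

        _document_indexing(str(mock_dataset.id), [doc.id for doc in mock_documents])

        # One session may be created per phase; each must have been closed.
        assert len(mock_db_session.all_sessions) >= 1
        for session in mock_db_session.all_sessions:
            assert session.close.called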
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
use the deprecated function.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -304,21 +399,9 @@ class TestBatchProcessing:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

# Create an iterator for documents
doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -357,19 +440,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL

@@ -407,19 +480,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX

@@ -444,7 +507,10 @@ class TestBatchProcessing:
"""
# Arrange
document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -482,19 +548,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -528,19 +584,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -635,19 +681,9 @@ class TestErrorHandling:
doc.stopped_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True

@@ -674,17 +710,9 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed")

@@ -708,17 +736,9 @@ class TestErrorHandling:
but not treated as a failure.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")

@@ -853,17 +873,9 @@ class TestTaskCancellation:
Session cleanup should happen in finally block.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -883,17 +895,9 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error")

@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)]

# Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]):  # Skip middle one
doc = MagicMock(spec=Document)

@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

# Create iterator that returns None for missing document
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
doc.stopped_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True

@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False

@@ -1273,19 +1246,9 @@ class TestIntegration:

# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1321,19 +1284,9 @@ class TestIntegration:

# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1415,17 +1368,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1465,17 +1410,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1555,19 +1492,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True

@@ -1612,19 +1539,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True

@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True

@@ -1826,19 +1733,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")

@@ -1866,7 +1763,7 @@ class TestRobustness:
- No exceptions occur

Expected behavior:
- Database session is closed
- All database sessions are closed
- No connection leaks
"""
# Arrange

@@ -1879,19 +1776,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)

mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

doc_iter = iter(mock_documents)

def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query

mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents

with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False

@@ -1899,10 +1786,11 @@ class TestRobustness:
# Act
_document_indexing(dataset_id, document_ids)

# Assert
assert mock_db_session.close.called
# Verify close is called exactly once
assert mock_db_session.close.call_count == 1
# Assert - All created sessions should be closed
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
assert len(mock_db_session.all_sessions) >= 1
for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"

def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
"""