refactor: use session factory instead of calling db.session directly (#31198)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Authored by wangxiaolei on 2026-01-21 13:43:06 +08:00, committed by GitHub
parent 071bbc6d74
commit 121d301a41
48 changed files with 2788 additions and 2693 deletions

View File

@@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
self.session_maker = session_maker
def on_graph_start(self):
pass
@@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
with self.session_maker() as session:
with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
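
The change above drops the injected sessionmaker and has the layer open sessions from the shared session_factory instead. Below is a minimal sketch of what such a factory could look like, assuming it wraps a SQLAlchemy sessionmaker bound to the application engine; the class name, database URL, and expire_on_commit choice are illustrative, and only the names session_factory and create_session come from the diff itself.

from contextlib import contextmanager
from collections.abc import Iterator

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker


class SessionFactory:
    def __init__(self, database_url: str):
        # Assumption: a single engine-bound sessionmaker shared by the whole process.
        self._engine = create_engine(database_url)
        self._session_maker = sessionmaker(bind=self._engine, expire_on_commit=False)

    @contextmanager
    def create_session(self) -> Iterator[Session]:
        # Hands out a session and guarantees it is closed when the block exits.
        session = self._session_maker()
        try:
            yield session
        finally:
            session.close()


# Module-level singleton, matching the "from core.db.session_factory import session_factory" imports in this PR.
session_factory = SessionFactory("sqlite:///example.db")  # placeholder URL

With a factory of this shape in place, TriggerPostLayer only needs the trigger log id; it no longer has to be constructed with a session_maker argument.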

View File

@@ -35,7 +35,6 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@@ -473,6 +472,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
# Lazy import to avoid circular import during module initialization
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
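
Moving the DifyAPIRepositoryFactory import inside the locked branch means it only runs on the first call, after module initialization has finished, which is what breaks the circular import. The following is a small, self-contained sketch of the same double-checked lazy initialization; the names Repo, _repo, _lock, and get_repo are made up for illustration.

import threading


class Repo:
    """Stand-in for the lazily imported repository type."""


_repo: Repo | None = None
_lock = threading.Lock()


def get_repo() -> Repo:
    global _repo
    if _repo is None:              # fast path: no lock taken once initialized
        with _lock:
            if _repo is None:      # re-check after acquiring the lock
                # In the real code this is where the deferred import happens,
                # so the cyclic import is only paid on first use.
                _repo = Repo()
    return _repo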

View File

@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
@@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
if dataset_document.indexing_status != "completed":
db.session.close()
return
if dataset_document.indexing_status != "completed":
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
indexing_cache_key = f"document_{dataset_document.id}_indexing"
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
)
.order_by(DocumentSegment.position.asc())
.all()
)
.order_by(DocumentSegment.position.asc())
.all()
)
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
session.query(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
# update segment to enable
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
session.commit()
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
# update segment to enable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)
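
After the refactor, each task opens exactly one session via session_factory.create_session() and scopes its whole body, including the try/except/finally, inside that with block, so the explicit db.session.close() calls disappear. The condensed sketch below shows that shape; Item, example_task, and the in-memory model are placeholders, and only the session_factory usage mirrors the diff.

from celery import shared_task
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from core.db.session_factory import session_factory


class Base(DeclarativeBase):
    pass


class Item(Base):  # stand-in for models such as DatasetDocument
    __tablename__ = "items"
    id: Mapped[str] = mapped_column(primary_key=True)
    error: Mapped[str | None] = mapped_column(default=None)


@shared_task
def example_task(item_id: str):
    with session_factory.create_session() as session:  # closed automatically on exit
        item = session.get(Item, item_id)
        if item is None:
            return  # early returns no longer need a manual close
        try:
            session.commit()  # the real tasks do their indexing work before committing
        except Exception as e:
            item.error = str(e)
            session.commit()
        finally:
            # the real tasks release redis locks / cache keys here
            pass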

View File

@@ -5,9 +5,9 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
active_jobs_key = f"annotation_import_active:{tenant_id}"
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
with session_factory.create_session() as session:
# get app info
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
)
session.add(annotation)
session.flush()
document = Document(
page_content=content["question"],
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
db.session.add(annotation)
db.session.flush()
document = Document(
page_content=content["question"],
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id, "annotation"
)
)
if not dataset_collection_binding:
raise NotFound("App annotation setting not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
if app_annotation_setting:
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id, "annotation"
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create(documents, duplicate_check=True)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
),
fg="green",
)
)
if not dataset_collection_binding:
raise NotFound("App annotation setting not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create(documents, duplicate_check=True)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
),
fg="green",
)
)
except Exception as e:
db.session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
# Clean up active job tracking to release concurrency slot
try:
redis_client.zrem(active_jobs_key, job_id)
logger.debug("Released concurrency slot for job: %s", job_id)
except Exception as cleanup_error:
# Log but don't fail if cleanup fails - the job will be auto-expired
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
# Close database session
db.session.close()
except Exception as e:
session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
# Clean up active job tracking to release concurrency slot
try:
redis_client.zrem(active_jobs_key, job_id)
logger.debug("Released concurrency slot for job: %s", job_id)
except Exception as cleanup_error:
# Log but don't fail if cleanup fails - the job will be auto-expired
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)

View File

@@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import exists, select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
return
app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if not app_annotation_setting:
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
db.session.close()
return
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=app_annotation_setting.collection_binding_id,
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if not app_annotation_setting:
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
return
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=app_annotation_setting.collection_binding_id,
)
# delete annotation setting
db.session.delete(app_annotation_setting)
db.session.commit()
try:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)
db.session.close()
# delete annotation setting
session.delete(app_annotation_setting)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)

View File

@@ -5,9 +5,9 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
@@ -33,92 +33,98 @@ def enable_annotation_reply_task(
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
return
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id,
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
db.session.add(new_app_annotation_setting)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = naive_utc_now()
session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id,
)
documents.append(document)
session.add(new_app_annotation_setting)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
db.session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)
db.session.close()
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"App annotations added to index: {app_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)

View File

@@ -10,13 +10,13 @@ from typing import Any
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
@@ -98,10 +98,7 @@ def _execute_workflow_common(
):
"""Execute workflow with common logic and trigger log updates."""
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
@@ -157,7 +154,7 @@ def _execute_workflow_common(
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
)

View File

@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
if not doc_form:
raise ValueError("doc_form is required")
try:
if not doc_form:
raise ValueError("doc_form is required")
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
storage.delete(file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
db.session.commit()
if file_ids:
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
db.session.delete(file)
session.commit()
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
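
Besides switching to the scoped session, this file also replaces per-row session.delete(...) loops with set-based delete() statements, so each table receives a single DELETE ... WHERE id IN (...) instead of one round trip per object. A minimal, self-contained sketch of that pattern follows; UploadRow and the in-memory sqlite engine are placeholders, not project code.

from sqlalchemy import create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class UploadRow(Base):  # stand-in for models like UploadFile
    __tablename__ = "upload_rows"
    id: Mapped[str] = mapped_column(primary_key=True)
    key: Mapped[str] = mapped_column(default="")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([UploadRow(id="a"), UploadRow(id="b")])
    session.flush()
    ids = list(session.scalars(select(UploadRow.id)))
    # One bulk DELETE ... WHERE id IN (...) instead of deleting objects one by one.
    session.execute(delete(UploadRow).where(UploadRow.id.in_(ids)))
    session.commit()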

View File

@@ -9,9 +9,9 @@ import pandas as pd
from celery import shared_task
from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
@@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset_document = db.session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
dataset_document = session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
raise ValueError("Document is not available.")
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
upload_file = session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
finally:
db.session.close()
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")

View File

@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
@@ -53,135 +53,155 @@ def clean_dataset_task(
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexStructureType
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
with session_factory.create_session() as session:
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
)
).all()
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexStructureType
for document in documents:
db.session.delete(document)
# delete document file
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(
f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
fg="yellow",
)
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
)
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
for document in documents:
session.delete(document)
segment_ids = [segment.id for segment in segments]
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(image_file.key)
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
db.session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
for document in documents:
try:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
file_ids = []
for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
continue
storage.delete(file.key)
db.session.delete(file)
except Exception:
continue
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
for file in files:
storage.delete(file.key)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(file_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Failed to rollback database session")
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception:
logger.exception("Failed to rollback database session")
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
# Explicitly close the session for test expectations and safety
try:
session.close()
except Exception:
logger.exception("Failed to close database session")
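
The extra session.close() in the finally block above is redundant with the context manager but harmless: SQLAlchemy's Session.close() releases the connection and may be called more than once, so closing inside the with block and again on exit does not raise. A tiny check of that behaviour, using an in-memory sqlite engine as a placeholder:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")

with Session(engine) as session:
    session.execute(text("SELECT 1"))  # force a connection to be checked out
    session.close()                    # explicit close inside the block
# The context manager calls close() again here; the second call is a no-op.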

View File

@@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(image_file.key)
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(image_file)
db.session.delete(segment)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
db.session.commit()
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
db.session.commit()
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when document deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned document when document deleted failed")

View File

@@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
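
The same hunk shows the second recurring change: per-row `db.session.delete(...)` loops become a single `delete()` statement per table. A condensed sketch; `bulk_delete_documents` is a hypothetical helper, with the id lists collected by the caller:

from sqlalchemy import delete
from sqlalchemy.orm import Session

from models.dataset import Document, DocumentSegment


def bulk_delete_documents(session: Session, document_ids: list[str], segment_ids: list[str]) -> None:
    # One DELETE statement per table instead of deleting row objects in a loop.
    session.execute(delete(Document).where(Document.id.in_(document_ids)))
    session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
    session.commit()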


@ -4,9 +4,9 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "waiting":
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
# update segment status to indexing
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
db.session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "waiting":
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, [document])
try:
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# update segment to completed
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
dataset = segment.dataset
end_at = time.perf_counter()
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, [document])
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)
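
Condensed, the flow above is: early returns happen inside the session block, failures flip the segment to an error state, and the redis indexing lock is always released. A sketch with the actual indexing work elided to a comment; the function name is illustrative:

from core.db.session_factory import session_factory
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment


def index_segment_example(segment_id: str) -> None:
    """Illustrative only: status bookkeeping around an indexing step."""
    with session_factory.create_session() as session:
        segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
        if not segment or segment.status != "waiting":
            return
        indexing_cache_key = f"segment_{segment.id}_indexing"
        try:
            session.query(DocumentSegment).filter_by(id=segment.id).update(
                {DocumentSegment.status: "indexing", DocumentSegment.indexing_at: naive_utc_now()}
            )
            session.commit()
            # ... build the Document and load it through the index processor ...
            session.query(DocumentSegment).filter_by(id=segment.id).update(
                {DocumentSegment.status: "completed", DocumentSegment.completed_at: naive_utc_now()}
            )
            session.commit()
        except Exception as e:
            segment.enabled = False
            segment.disabled_at = naive_utc_now()
            segment.status = "error"
            segment.error = str(e)
            session.commit()
        finally:
            redis_client.delete(indexing_cache_key)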


@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
.all()
)
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
.order_by(DocumentSegment.position.asc())
.all()
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style(
"Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Deal dataset vector index failed")


@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
.order_by(DocumentSegment.position.asc())
.all()
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
end_at = time.perf_counter()
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Deal dataset vector index failed")
finally:
db.session.close()
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
elif action == "update":
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(
click.style(
f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Deal dataset vector index failed")
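
Both dataset-index tasks above share the same bookkeeping shape around the rebuild: flip the matching documents to "indexing", rebuild each one, then record "completed" or "error" per document. A stripped-down sketch with the vector rebuild elided to a comment; the function name is illustrative:

from core.db.session_factory import session_factory
from models.dataset import Document as DatasetDocument


def reindex_dataset_documents_example(dataset_id: str) -> None:
    """Illustrative only: per-document status transitions during a rebuild."""
    with session_factory.create_session() as session:
        dataset_documents = (
            session.query(DatasetDocument)
            .where(
                DatasetDocument.dataset_id == dataset_id,
                DatasetDocument.indexing_status == "completed",
                DatasetDocument.enabled == True,
                DatasetDocument.archived == False,
            )
            .all()
        )
        if not dataset_documents:
            return
        ids = [doc.id for doc in dataset_documents]
        session.query(DatasetDocument).where(DatasetDocument.id.in_(ids)).update(
            {"indexing_status": "indexing"}, synchronize_session=False
        )
        session.commit()
        for dataset_document in dataset_documents:
            try:
                # ... collect the document's segments and reload them into the vector index ...
                session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                    {"indexing_status": "completed"}, synchronize_session=False
                )
            except Exception as e:
                session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                    {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                )
            session.commit()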


@ -3,7 +3,7 @@ import logging
from celery import shared_task
from configs import dify_config
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
with session_factory.create_session() as session:
account = session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)


@ -4,7 +4,7 @@ import time
import click
from celery import shared_task
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()
try:
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
fg="green",
with session_factory.create_session() as session:
try:
session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
)
except Exception as e:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
db.session.rollback()
raise e
finally:
db.session.close()
session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
(
f"Succeeded cleaning data from db for conversation_id {conversation_id} "
f"latency: {end_at - start_at}"
),
fg="green",
)
)
except Exception:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
session.rollback()
raise
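
The error handling above, in isolation: the bulk deletes and the commit live inside the scoped session, and on failure the task logs, rolls back, and re-raises. A minimal sketch with the model list trimmed to two tables:

import logging

from core.db.session_factory import session_factory
from models.model import Message, MessageFeedback

logger = logging.getLogger(__name__)


def delete_conversation_rows_example(conversation_id: str) -> None:
    """Illustrative only: bulk delete with rollback-and-reraise on failure."""
    with session_factory.create_session() as session:
        try:
            session.query(Message).where(Message.conversation_id == conversation_id).delete(
                synchronize_session=False
            )
            session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
                synchronize_session=False
            )
            session.commit()
        except Exception:
            logger.exception("Failed to delete data for conversation_id: %s", conversation_id)
            session.rollback()
            raise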


@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@ -26,49 +26,52 @@ def delete_segment_from_index_task(
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
dataset_document = session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload file
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("delete segment from index failed")
finally:
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("delete segment from index failed")


@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
try:
dataset = segment.dataset
end_at = time.perf_counter()
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
session.commit()
finally:
redis_client.delete(indexing_cache_key)


@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
db.session.close()
return
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
db.session.close()
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
).all()
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)


@ -3,12 +3,12 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.datasource_provider_service import DatasourceProviderService
@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
db.session.commit()
db.session.close()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
last_edited_time = loader.get_notion_last_edited_time()
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
finally:
db.session.close()
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)


@ -6,11 +6,11 @@ import click
from celery import shared_task
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
except Exception as e:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
return
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(


@ -3,8 +3,9 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

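The update task above also switches the segment cleanup from a per-row `session.delete()` loop to a single bulk DELETE. A minimal sketch of that step, assuming the models and the `session_factory` API shown in the diff; `drop_segments_for_document` is an illustrative name:

from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from models.dataset import DocumentSegment


def drop_segments_for_document(document_id: str) -> list[str]:
    # Returns the index node ids so the caller can clean the vector index,
    # mirroring the order of operations in the task above.
    with session_factory.create_session() as session:
        segments = session.scalars(
            select(DocumentSegment).where(DocumentSegment.document_id == document_id)
        ).all()
        if not segments:
            return []
        index_node_ids = [segment.index_node_id for segment in segments]
        # One DELETE ... WHERE id IN (...) instead of per-row session.delete().
        segment_ids = [segment.id for segment in segments]
        session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
        session.commit()
        return index_node_ids
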
View File

@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
documents = []
documents: list[Document] = []
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
with session_factory.create_session() as session:
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document in documents:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
if document:
for document in documents:
logger.info(click.style(f"Start process document: {document.id}", fg="green"))
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
session.add(document)
session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
indexing_runner = IndexingRunner()
indexing_runner.run(list(documents))
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
@shared_task(queue="dataset")

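In `_duplicate_document_indexing_task`, the former per-id `first()` lookups become one `IN` query, and the billing-failure branch now marks the whole batch from that single result set. A condensed sketch under the same assumptions; `load_documents` and `fail_batch` are illustrative helpers, not code from the repository:

from sqlalchemy import select

from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.dataset import Document


def load_documents(session, dataset_id: str, document_ids: list[str]) -> list[Document]:
    # One IN-query replaces the former per-id lookups.
    return list(
        session.scalars(
            select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
        ).all()
    )


def fail_batch(dataset_id: str, document_ids: list[str], error: Exception) -> None:
    # Mirrors the except-branch above: fetch the batch once and mark every row.
    with session_factory.create_session() as session:
        for document in load_documents(session, dataset_id, document_ids):
            document.indexing_status = "error"
            document.error = str(error)
            document.stopped_at = naive_utc_now()
            session.add(document)
        session.commit()
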
View File

@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
try:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
multimodel_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodel_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
child_documents.append(child_document)
document.children = child_documents
multimodel_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodel_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
# save vector index
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
# save vector index
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)

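For `enable_segment_to_index_task`, the payload handed to `index_processor.load()` is a RAG `Document` built from the segment's content plus, for the parent-child index form, one `ChildDocument` per child chunk. A minimal sketch of that assembly, assuming the segment attributes used in the diff and leaving out the multimodal-attachment branch:

from core.rag.models.document import ChildDocument, Document


def build_index_document(segment) -> Document:
    # `segment` is assumed to expose content, index_node_id, index_node_hash,
    # document_id, dataset_id and get_child_chunks(), as in the task above.
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )
    child_chunks = segment.get_child_chunks()
    if child_chunks:
        document.children = [
            ChildDocument(
                page_content=chunk.content,
                metadata={
                    "doc_id": chunk.index_node_id,
                    "doc_hash": chunk.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            for chunk in child_chunks
        ]
    return document
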
View File

@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()
return
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
return
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)

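The batch variant above handles failures with one bulk UPDATE over the affected segments and then clears the per-segment Redis indexing keys in its finally block. A minimal sketch of that error path; `disable_segments_on_error` is an illustrative name:

from core.db.session_factory import session_factory
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment


def disable_segments_on_error(segment_ids: list[str], dataset_id: str, document_id: str, error: Exception) -> None:
    # Bulk UPDATE mirroring the except-branch above, then release the
    # per-segment indexing cache keys the way the finally-block does.
    with session_factory.create_session() as session:
        session.query(DocumentSegment).where(
            DocumentSegment.id.in_(segment_ids),
            DocumentSegment.dataset_id == dataset_id,
            DocumentSegment.document_id == document_id,
        ).update(
            {
                "error": str(error),
                "status": "error",
                "disabled_at": naive_utc_now(),
                "enabled": False,
            }
        )
        session.commit()
    for segment_id in segment_ids:
        redis_client.delete(f"segment_{segment_id}_indexing")
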
View File

@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)

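`recover_document_indexing_task` simply re-dispatches on `indexing_status`. The same branching, restated here as a dispatch table purely for readability (the task itself uses plain if/elif):

from core.indexing_runner import IndexingRunner


def resume_document(document) -> None:
    # Waiting/parsing/cleaning documents are re-run from scratch; splitting and
    # indexing documents resume from their respective stages.
    runner = IndexingRunner()
    handlers = {
        "waiting": lambda: runner.run([document]),
        "parsing": lambda: runner.run([document]),
        "cleaning": lambda: runner.run([document]),
        "splitting": lambda: runner.run_in_splitting_status(document),
        "indexing": lambda: runner.run_in_indexing_status(document),
    }
    handler = handlers.get(document.indexing_status)
    if handler is not None:
        handler()
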
View File

@ -1,14 +1,17 @@
import logging
import time
from collections.abc import Callable
from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from extensions.ext_database import db
from models import (
ApiToken,
@ -77,7 +80,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
@ -89,8 +91,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@ -101,8 +103,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
def del_site(session, site_id: str):
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@ -113,8 +115,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
def del_mcp_server(session, mcp_server_id: str):
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@ -125,8 +127,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
def del_api_token(session, api_token_id: str):
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@ -137,8 +139,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(installed_app_id: str):
db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
def del_installed_app(session, installed_app_id: str):
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -149,10 +151,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(recommended_app_id: str):
db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
synchronize_session=False
)
def del_recommended_app(session, recommended_app_id: str):
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@ -163,8 +163,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str):
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
def del_annotation_hit_history(session, annotation_hit_history_id: str):
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@ -175,8 +175,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
def del_annotation_setting(annotation_setting_id: str):
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
def del_annotation_setting(session, annotation_setting_id: str):
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@ -189,8 +189,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(dataset_join_id: str):
db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
def del_dataset_join(session, dataset_join_id: str):
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@ -201,8 +201,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(workflow_id: str):
db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
def del_workflow(session, workflow_id: str):
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -241,10 +241,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
synchronize_session=False
)
def del_workflow_app_log(session, workflow_app_log_id: str):
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -255,11 +253,11 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(conversation_id: str):
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
def del_conversation(session, conversation_id: str):
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@ -270,28 +268,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
def _delete_conversation_variables(*, app_id: str):
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
with session_factory.create_session() as session:
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
session.execute(stmt)
session.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
def del_message(session, message_id: str):
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
db.session.query(Message).where(Message.id == message_id).delete()
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@ -302,8 +298,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(tool_provider_id: str):
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
def del_tool_provider(session, tool_provider_id: str):
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@ -316,8 +312,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(tag_binding_id: str):
db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
def del_tag_binding(session, tag_binding_id: str):
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@ -328,8 +324,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(end_user_id: str):
db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
def del_end_user(session, end_user_id: str):
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -340,10 +336,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(trace_app_config_id: str):
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
synchronize_session=False
)
def del_trace_app_config(session, trace_app_config_id: str):
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@ -381,14 +375,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with db.engine.begin() as conn:
with session_factory.create_session() as session:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
@ -399,7 +393,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Clean up associated Offload data first
if file_ids:
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
files_deleted = _delete_draft_variable_offload_data(session, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
@ -407,8 +401,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
deleted_result = cast(
CursorResult[Any],
session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
)
batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
@ -423,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
return total_deleted
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
@ -434,7 +431,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
4. Deletes WorkflowDraftVariableFile records
Args:
conn: Database connection
session: Database session
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
@ -450,12 +447,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids
"""
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids \
"""
result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
@ -473,17 +470,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
DELETE FROM upload_files
WHERE id IN :upload_file_ids
"""
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
DELETE \
FROM upload_files
WHERE id IN :upload_file_ids \
"""
session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
DELETE FROM workflow_draft_variable_files
WHERE id IN :file_ids
"""
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
DELETE \
FROM workflow_draft_variable_files
WHERE id IN :file_ids \
"""
session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
@ -493,8 +492,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(trigger_id: str):
db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
def del_app_trigger(session, trigger_id: str):
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -505,8 +504,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(trigger_id: str):
db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
def del_plugin_trigger(session, trigger_id: str):
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -519,8 +518,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(trigger_id: str):
db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
def del_webhook_trigger(session, trigger_id: str):
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -533,10 +532,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(plan_id: str):
db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
synchronize_session=False
)
def del_schedule_plan(session, plan_id: str):
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -547,8 +544,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(log_id: str):
db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
def del_trigger_log(session, log_id: str):
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -560,18 +557,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
with session_factory.create_session() as session:
rs = session.execute(sa.text(query_sql), params)
rows = rs.fetchall()
if not rows:
break
for i in rs:
for i in rows:
record_id = str(i.id)
try:
delete_func(record_id)
db.session.commit()
delete_func(session, record_id)
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
continue
# continue with next record even if one deletion fails
session.rollback()
break
session.commit()
rs.close()

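`_delete_records` is the workhorse behind all of the `del_*` callbacks above: it pages through ids 1000 at a time and now hands the batch's session to each callback instead of letting them touch `db.session`. A condensed reading of that helper; the per-record error handling is simplified here and may not match the diff's exact rollback ordering:

from collections.abc import Callable

import sqlalchemy as sa

from core.db.session_factory import session_factory


def delete_in_batches(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
    # query_sql is expected to select ids with a LIMIT, e.g.
    # "select id from sites where app_id=:app_id limit 1000"; delete_func now
    # receives the batch's session as its first argument.
    while True:
        with session_factory.create_session() as session:
            rows = session.execute(sa.text(query_sql), params).fetchall()
            if not rows:
                break
            for row in rows:
                record_id = str(row.id)
                try:
                    delete_func(session, record_id)
                except Exception:
                    # Undo the failed record and keep going with the next one.
                    session.rollback()
                    continue
            session.commit()
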
View File

@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str):
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status != "completed":
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
db.session.close()
return
if document.indexing_status != "completed":
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
return
indexing_cache_key = f"document_{document.id}_indexing"
indexing_cache_key = f"document_{document.id}_indexing"
try:
dataset = document.dataset
try:
dataset = document.dataset
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: naive_utc_now(),
}
)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(
click.style(
f"Document removed from index: {document.id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
session.commit()
finally:
redis_client.delete(indexing_cache_key)

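`remove_document_from_index_task` disables all of a document's segments with a single UPDATE keyed by column objects. A minimal sketch of that statement, assuming the `DocumentSegment` columns used in the diff; `disable_document_segments` is an illustrative name:

from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment


def disable_document_segments(session, document) -> None:
    # Single UPDATE across all of the document's segments, as in the task above.
    session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
        {
            DocumentSegment.enabled: False,
            DocumentSegment.disabled_at: naive_utc_now(),
            DocumentSegment.disabled_by: document.disabled_by,
            DocumentSegment.updated_at: naive_utc_now(),
        }
    )
    session.commit()
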
View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
user = session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(e)
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
session.add(document)
session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
raise e
finally:
db.session.close()
raise e

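The retry task above shares one error path between the quota check and the indexing attempt: persist the failure on the document and release the `document_{id}_is_retried` guard key. A minimal sketch of that path as a helper; `fail_retry` is an illustrative factoring, not a function in the codebase:

from core.db.session_factory import session_factory
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document


def fail_retry(session, document: Document, error: Exception) -> None:
    # Record the failure inside the task's session and drop the retry guard
    # key so the document can be retried again later.
    document.indexing_status = "error"
    document.error = str(error)
    document.stopped_at = naive_utc_now()
    session.add(document)
    session.commit()
    redis_client.delete(f"document_{document.id}_is_retried")
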
View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(e)
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
session.add(document)
session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))

View File

@ -16,6 +16,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
with Session(db.engine) as session:
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
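Here a per-call `Session(db.engine)` becomes a call to the shared factory. Both yield a context-managed Session; the factory simply centralizes the engine binding and session defaults. A minimal before/after sketch (function names are illustrative, and the primary-key lookup on TriggerSubscription is assumed):

from sqlalchemy.orm import Session

from core.db.session_factory import session_factory
from extensions.ext_database import db
from models.trigger import TriggerSubscription


def get_subscription_old(subscription_id: str) -> TriggerSubscription | None:
    # Before: each call site binds its own Session to the Flask-SQLAlchemy engine.
    with Session(db.engine) as session:
        return session.get(TriggerSubscription, subscription_id)


def get_subscription_new(subscription_id: str) -> TriggerSubscription | None:
    # After: the shared factory owns the engine binding and session options.
    with session_factory.create_session() as session:
        return session.get(TriggerSubscription, subscription_id)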

View File

@ -7,9 +7,9 @@ from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
with Session(db.engine) as session:
with session_factory.create_session() as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:

View File

@ -10,11 +10,10 @@ import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
@ -46,10 +45,7 @@ def save_workflow_execution_task(
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
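Previously this Celery task built a fresh sessionmaker(bind=db.engine, expire_on_commit=False) on every invocation; the shared factory removes that per-call setup. A condensed sketch of the resulting task shape (the task name and body are illustrative, with persistence omitted):

from celery import shared_task

from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution


@shared_task
def save_execution_sketch(execution_data: dict) -> bool:
    with session_factory.create_session() as session:
        # Deserialize, then map onto the ORM model and persist it (omitted here).
        execution = WorkflowExecution.model_validate(execution_data)
        session.commit()
        return True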

View File

@ -10,13 +10,12 @@ import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)

View File

@ -1,15 +1,14 @@
import logging
from celery import shared_task
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")

View File

@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from extensions.ext_database import db
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele
@pytest.fixture
def app_and_tenant(flask_req_ctx):
tenant_id = uuid.uuid4()
tenant = Tenant(
id=tenant_id,
name="test_tenant",
)
db.session.add(tenant)
with session_factory.create_session() as session:
tenant = Tenant(name="test_tenant")
session.add(tenant)
session.flush()
app = App(
tenant_id=tenant_id, # Now tenant.id will have a value
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.flush()
yield (tenant, app)
app = App(
tenant_id=tenant.id,
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
session.add(app)
session.flush()
# Cleanup with proper error handling
db.session.delete(app)
db.session.delete(tenant)
# return detached objects (ids will be used by tests)
return (tenant, app)
class TestDeleteDraftVariablesIntegration:
@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration:
"""Create test data with apps and draft variables."""
tenant, app = app_and_tenant
# Create a second app for testing
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.commit()
# Create draft variables for both apps
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var1)
variables_app1.append(var1)
session.add(app2)
session.flush()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var2)
variables_app2.append(var2)
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var1)
variables_app1.append(var1)
# Commit all the variables to the database
db.session.commit()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var2)
variables_app2.append(var2)
session.commit()
app2_id = app2.id
yield {
"app1": app,
"app2": app2,
"app2": App(id=app2_id), # dummy with id to avoid open session
"tenant": tenant,
"variables_app1": variables_app1,
"variables_app2": variables_app2,
}
# Cleanup - refresh session and check if objects still exist
db.session.rollback() # Clear any pending changes
# Clean up remaining variables
cleanup_query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id.in_([app.id, app2_id]))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_query)
# Clean up app2
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
session.execute(cleanup_query)
app2_obj = session.get(App, app2_id)
if app2_obj:
session.delete(app2_obj)
session.commit()
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
"""Test that batch deletion only removes variables for the specified app."""
data = setup_test_data
app1_id = data["app1"].id
app2_id = data["app2"].id
# Verify initial state
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
with session_factory.create_session() as session:
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_before == 5
assert app2_vars_before == 5
# Delete app1 variables
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
# Verify results
assert deleted_count == 5
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0 # All app1 variables deleted
assert app2_vars_after == 5 # App2 variables unchanged
with session_factory.create_session() as session:
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0
assert app2_vars_after == 5
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
"""Test batch deletion with small batch size processes all records."""
data = setup_test_data
app1_id = data["app1"].id
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
assert deleted_count == 5
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
"""Test that deleting variables for nonexistent app returns 0."""
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
nonexistent_app_id = str(uuid.uuid4())
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
assert deleted_count == 0
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
"""Test that _delete_draft_variables wrapper function works correctly."""
data = setup_test_data
app1_id = data["app1"].id
# Verify initial state
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_before == 5
# Call wrapper function
deleted_count = _delete_draft_variables(app1_id)
# Verify results
assert deleted_count == 5
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
"""Test batch deletion with larger dataset to verify batching logic."""
tenant, app = app_and_tenant
# Create many draft variables
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
variables.append(var)
variable_ids = [i.id for i in variables]
# Commit the variables to the database
db.session.commit()
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
try:
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining_vars == 0
with session_factory.create_session() as session:
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining == 0
finally:
query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.id.in_(variable_ids),
with session_factory.create_session() as session:
query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(query)
session.execute(query)
session.commit()
class TestDeleteDraftVariablesWithOffloadIntegration:
"""Integration tests for draft variable deletion with Offload data."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with draft variables that have associated Offload files."""
tenant, app = app_and_tenant
# Create UploadFile records
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
db.session.add(upload_file1)
db.session.add(upload_file2)
db.session.flush()
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
# Create WorkflowDraftVariableFile records
from core.variables.types import SegmentType
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
db.session.add(var_file1)
db.session.add(var_file2)
db.session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
# Create WorkflowDraftVariable records with file associations
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
# Create a regular variable without Offload data
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
db.session.add(draft_var1)
db.session.add(draft_var2)
db.session.add(draft_var3)
db.session.commit()
yield data
yield {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
# Cleanup
db.session.rollback()
# Clean up any remaining records
for table, ids in [
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
db.session.execute(cleanup_query)
db.session.commit()
with session_factory.create_session() as session:
session.rollback()
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
"""Test that deleting draft variables also cleans up associated Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to succeed
mock_storage.delete.return_value = None
# Verify initial state
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
upload_files_before = db.session.query(UploadFile).count()
assert draft_vars_before == 3 # 2 with files + 1 regular
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify results
assert deleted_count == 3
# Check that all draft variables are deleted
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Check that associated Offload data is cleaned up
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
assert var_files_after == 0 # All variable files should be deleted
assert upload_files_after == 0 # All upload files should be deleted
# Verify storage deletion was called for both files
assert mock_storage.delete.call_count == 2
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
assert "test/file1.json" in storage_keys_deleted
@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that database cleanup continues even when storage deletion fails."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to fail for first file, succeed for second
mock_storage.delete.side_effect = [Exception("Storage error"), None]
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify that all draft variables are still deleted
assert deleted_count == 3
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Database cleanup should still succeed even with storage errors
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage deletion was attempted for both files
assert mock_storage.delete.call_count == 2
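The storage-failure test above leans on unittest.mock's side_effect accepting a sequence: the first call raises, the second succeeds, and call_count still reaches 2 because the cleanup code catches the error and continues. A tiny standalone illustration, not tied to this repository:

from unittest.mock import MagicMock

storage = MagicMock()
storage.delete.side_effect = [Exception("Storage error"), None]

for key in ("test/file1.json", "test/file2.json"):
    try:
        storage.delete(key)
    except Exception:
        # Mirrors the task code, which logs the failure and keeps going.
        pass

assert storage.delete.call_count == 2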
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
"""Test deletion with mix of variables with and without Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Create additional app with only regular variables (no offload data)
tenant = data["tenant"]
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.flush()
# Add regular variables to app2
regular_vars = []
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var)
regular_vars.append(var)
db.session.commit()
session.add(app2)
session.flush()
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
session.commit()
try:
# Mock storage deletion
mock_storage.delete.return_value = None
# Delete variables for app2 (no offload data)
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
assert deleted_count_app2 == 3
# Verify storage wasn't called for app2 (no offload files)
mock_storage.delete.assert_not_called()
# Delete variables for original app (with offload data)
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count_app1 == 3
# Now storage should be called for the offload files
assert mock_storage.delete.call_count == 2
finally:
# Cleanup app2 and its variables
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_vars_query)
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
with session_factory.create_session() as session:
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
session.execute(cleanup_vars_query)
app2_obj = session.get(App, app2.id)
if app2_obj:
session.delete(app2_obj)
session.commit()
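The fixtures above settle into a per-phase pattern: a short-lived session for setup that commits and hands back plain ids, and a separate session for teardown. A condensed sketch of that shape (helper names are illustrative; the WorkflowDraftVariable import path is assumed, since it is not shown in this diff):

import uuid

from sqlalchemy import delete

from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from models.workflow import WorkflowDraftVariable  # import path assumed


def create_draft_variables(app_id: str, count: int) -> list[str]:
    # Setup phase: one short-lived session, committed before the test body runs.
    with session_factory.create_session() as session:
        variables = [
            WorkflowDraftVariable.new_node_variable(
                app_id=app_id,
                node_id=f"node_{i}",
                name=f"var_{i}",
                value=StringSegment(value="test_value"),
                node_execution_id=str(uuid.uuid4()),
            )
            for i in range(count)
        ]
        session.add_all(variables)
        session.commit()
        return [variable.id for variable in variables]


def cleanup_draft_variables(variable_ids: list[str]) -> None:
    # Teardown phase: a separate session and one bulk delete keyed by primary key.
    with session_factory.create_session() as session:
        session.execute(
            delete(WorkflowDraftVariable)
            .where(WorkflowDraftVariable.id.in_(variable_ids))
            .execution_options(synchronize_session=False)
        )
        session.commit()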

View File

@ -39,23 +39,22 @@ class TestCleanDatasetTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client
# Clear all test data
db.session.query(DatasetMetadataBinding).delete()
db.session.query(DatasetMetadata).delete()
db.session.query(AppDatasetJoin).delete()
db.session.query(DatasetQuery).delete()
db.session.query(DatasetProcessRule).delete()
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -103,10 +102,8 @@ class TestCleanDatasetTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -115,8 +112,8 @@ class TestCleanDatasetTask:
status="active",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account relationship
tenant_account_join = TenantAccountJoin(
@ -125,8 +122,8 @@ class TestCleanDatasetTask:
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
db_session_with_containers.add(tenant_account_join)
db_session_with_containers.commit()
return account, tenant
@ -155,10 +152,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -194,10 +189,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -232,10 +225,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
@ -267,10 +258,8 @@ class TestCleanDatasetTask:
used=False,
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
@ -302,31 +291,29 @@ class TestCleanDatasetTask:
)
# Verify results
from extensions.ext_database import db
# Check that dataset-related data was cleaned up
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(documents) == 0
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(metadata) == 0
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
assert len(process_rules) == 0
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
assert len(app_joins) == 0
# Verify index processor was called
@ -378,9 +365,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Create dataset metadata and bindings
metadata = DatasetMetadata(
@ -403,11 +388,9 @@ class TestCleanDatasetTask:
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
from extensions.ext_database import db
db.session.add(metadata)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(metadata)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -421,22 +404,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify index processor was called
@ -489,12 +474,13 @@ class TestCleanDatasetTask:
mock_index_processor.clean.assert_called_once()
# Check that all data was cleaned up
from extensions.ext_database import db
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_segments) == 0
# Recreate data for next test case
@ -540,14 +526,13 @@ class TestCleanDatasetTask:
)
# Verify results - even with vector cleanup failure, documents and segments should be deleted
from extensions.ext_database import db
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@ -608,10 +593,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Mock the get_image_upload_file_ids function to return our image file IDs
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
@ -629,16 +612,18 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@ -745,22 +730,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify performance expectations
@ -808,9 +795,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock storage to raise exceptions
mock_storage = mock_external_service_dependencies["storage"]
@ -827,18 +812,13 @@ class TestCleanDatasetTask:
)
# Verify results
# Check that documents were still deleted despite storage failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite storage failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Note: When storage operations fail, database deletions may be rolled back by the implementation.
# This test focuses on ensuring the task handles the exception and continues execution/logging.
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@ -890,10 +870,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document with special characters in name
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
@ -912,8 +890,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create segment with special characters and very long content
long_content = "Very long content " * 100 # Long content within reasonable limits
@ -934,8 +912,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Create upload file with special characters in name
special_filename = f"test_file_{special_content}.txt"
@ -952,14 +930,14 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
used=False,
)
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
# Update document with file reference
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Save upload file ID for verification
upload_file_id = upload_file.id
@ -975,8 +953,8 @@ class TestCleanDatasetTask:
special_metadata.id = str(uuid.uuid4())
special_metadata.created_at = datetime.now()
db.session.add(special_metadata)
db.session.commit()
db_session_with_containers.add(special_metadata)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -990,19 +968,19 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called

View File

@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database and Redis before each test to ensure isolation."""
from extensions.ext_database import db
# Clear all test data
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
status="normal",
plan="basic",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
)
# Mock global database session to simulate transaction issues
from extensions.ext_database import db
original_commit = db.session.commit
commit_called = False
def mock_commit():
nonlocal commit_called
if not commit_called:
commit_called = True
raise Exception("Database commit failed")
return original_commit()
db.session.commit = mock_commit
# Simulate an error during indexing to trigger rollback path
mock_processor = mock_external_service_dependencies["index_processor"]
mock_processor.load.side_effect = Exception("Simulated indexing error")
# Act: Execute the task
create_segment_to_index_task(segment.id)
@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
assert segment.disabled_at is not None
assert segment.error is not None
# Restore original commit method
db.session.commit = original_commit
def test_create_segment_to_index_metadata_validation(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Set the current tenant for the account
account.current_tenant = tenant
@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
built_in_field_enabled=False,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
segments.append(segment)
from extensions.ext_database import db
for segment in segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segments
@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Mock db.session.close to verify it's called
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify session was closed
mock_close.assert_called()
# Assert
assert result is None # Task should complete without returning a value
# Session lifecycle is managed by the context manager; no explicit close assertion is needed
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
"""

View File

@ -6,7 +6,6 @@ from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from the database since _document_indexing closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
indexing_status="completed", # Already completed
enabled=True,
)
db.session.add(doc1)
db_session_with_containers.add(doc1)
extra_documents.append(doc1)
# Document with disabled status
@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=False, # Disabled
)
db.session.add(doc2)
db_session_with_containers.add(doc2)
extra_documents.append(doc2)
db.session.commit()
db_session_with_containers.commit()
all_documents = base_documents + extra_documents
document_ids = [doc.id for doc in all_documents]
@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

View File

@ -4,7 +4,6 @@ import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
# Refresh to ensure all relationships are loaded
for document in documents:
db.session.refresh(document)
db_session_with_containers.refresh(document)
return dataset, documents, segments
@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
segment_ids = [seg.id for seg in segments]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
# Re-query segments from database using captured IDs to avoid stale ORM instances
for seg_id in segment_ids:
deleted_segment = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
)
assert deleted_segment is None
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
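Capturing `segment_ids` before running the task is the key detail here: once the task's own session has deleted the rows and `db_session_with_containers.expire_all()` has cleared this session's cached state, touching the old ORM instances forces a refresh against rows that no longer exist. A minimal sketch of the failure mode the captured-ID approach avoids (illustrative only; `content` is assumed to be a mapped column on DocumentSegment):

from sqlalchemy.orm.exc import ObjectDeletedError

db_session_with_containers.expire_all()
try:
    _ = segments[0].content  # refresh of an expired instance whose row was already deleted
except ObjectDeletedError:
    pass  # exactly the situation that re-querying by seg_id sidesteps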
@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")

View File

@ -49,10 +49,14 @@ def pipeline_id():
@pytest.fixture
def mock_db_session():
"""Mock database session with query capabilities."""
with patch("tasks.clean_dataset_task.db") as mock_db:
"""Mock database session via session_factory.create_session()."""
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
mock_session = MagicMock()
mock_db.session = mock_session
# context manager for create_session()
cm = MagicMock()
cm.__enter__.return_value = mock_session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
# Setup query chain
mock_query = MagicMock()
@ -66,7 +70,10 @@ def mock_db_session():
# Setup execute for JOIN queries
mock_session.execute.return_value.all.return_value = []
yield mock_db
# Yield an object with a `.session` attribute to keep tests unchanged
wrapper = MagicMock()
wrapper.session = mock_session
yield wrapper
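The `__enter__`/`__exit__` wiring matters because the task under test now opens its session as a context manager. A minimal sketch of the call shape this fixture emulates (the DELETE statement is illustrative, not the task's exact SQL):

from sqlalchemy import text
from core.db.session_factory import session_factory

def _cleanup_sketch(dataset_id: str) -> None:
    # create_session() returns a context manager whose __enter__ yields the Session;
    # the batch DELETEs end up in session.execute(), which the tests inspect.
    with session_factory.create_session() as session:
        session.execute(
            text("DELETE FROM document_segments WHERE dataset_id = :dataset_id"),
            {"dataset_id": dataset_id},
        )
        session.commit()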
@pytest.fixture
@ -227,7 +234,9 @@ class TestBasicCleanup:
# Assert
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
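This whitespace-normalized substring check recurs throughout the file; it keeps the assertions independent of how the multi-line SQL is formatted. If the pattern keeps spreading, a small helper (hypothetical, not part of this change) would remove the repetition:

def executed_sql(session_mock) -> list[str]:
    """Collapse whitespace in every statement passed to session.execute()."""
    return [" ".join(str(c.args[0]).split()) for c in session_mock.execute.call_args_list]

# assert any("DELETE FROM document_segments" in sql for sql in executed_sql(mock_db_session.session))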
def test_clean_dataset_task_deletes_related_records(
@ -413,7 +422,9 @@ class TestErrorHandling:
# Assert - documents and segments should still be deleted
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
def test_clean_dataset_task_storage_delete_failure_continues(
@ -461,7 +472,7 @@ class TestErrorHandling:
[mock_segment], # segments
]
mock_get_image_upload_file_ids.return_value = [image_file_id]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
mock_storage.delete.side_effect = Exception("Storage service unavailable")
# Act
@ -476,8 +487,9 @@ class TestErrorHandling:
# Assert - storage delete was attempted for image file
mock_storage.delete.assert_called_with(mock_upload_file.key)
# Image file should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_database_error_rollback(
self,
@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_attachment_file.key)
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Attachment file and binding are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
def test_clean_dataset_task_attachment_storage_failure(
self,
@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:
# Assert - storage delete was attempted
mock_storage.delete.assert_called_once()
# Records should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Records are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
# ============================================================================
@ -784,7 +799,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
# Act
clean_dataset_task(
@ -798,7 +813,9 @@ class TestUploadFileCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_upload_file.key)
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_handles_missing_upload_file(
self,
@ -832,7 +849,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@ -949,11 +966,11 @@ class TestImageFileCleanup:
[mock_segment], # segments
]
# Setup a mock query chain that returns files in sequence
# Setup a mock query chain that returns files in batch (align with .in_().all())
mock_query = MagicMock()
mock_where = MagicMock()
mock_query.where.return_value = mock_where
mock_where.first.side_effect = mock_image_files
mock_where.all.return_value = mock_image_files
mock_db_session.session.query.return_value = mock_query
# Act
@ -966,10 +983,10 @@ class TestImageFileCleanup:
doc_form="paragraph_index",
)
# Assert
assert mock_storage.delete.call_count == 2
mock_storage.delete.assert_any_call("images/image-1.jpg")
mock_storage.delete.assert_any_call("images/image-2.jpg")
# Assert - each expected image key was deleted at least once
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
assert "images/image-1.jpg" in calls
assert "images/image-2.jpg" in calls
def test_clean_dataset_task_handles_missing_image_file(
self,
@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
]
# Image file not found
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@ -1086,14 +1103,15 @@ class TestEdgeCases:
doc_form="paragraph_index",
)
# Assert - all documents and segments should be deleted
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
delete_calls = mock_db_session.session.delete.call_args_list
deleted_items = [call[0][0] for call in delete_calls]
for doc in mock_documents:
assert doc in deleted_items
for seg in mock_segments:
assert seg in deleted_items
# Verify a batch DELETE on document_segments occurred
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
def test_clean_dataset_task_document_with_empty_data_source_info(
self,

View File

@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
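Routing `__exit__` through `session.close()` is what keeps close-related assertions meaningful: with a real Session the `with` block closes it on exit, and the mock reproduces that observable effect. A self-contained sketch of the idea:

from unittest.mock import MagicMock

session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session

def _exit(*args, **kwargs):
    session.close()  # returns None so exceptions still propagate out of the with-block

cm.__exit__.side_effect = _exit

with cm as s:  # mirrors `with session_factory.create_session() as session:`
    s.query("Document")
assert s.close.called  # recorded by the context-manager teardown, not by the test body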
@pytest.fixture

View File

@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
"""Mock the db.session used in delete_account_task."""
with patch("tasks.delete_account_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock session via session_factory.create_session()."""
with patch("tasks.delete_account_task.session_factory") as mock_sf:
session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture

View File

@ -109,13 +109,25 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
def test_successful_sync_when_page_updated(
self,
@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])

View File

@ -94,13 +94,25 @@ def mock_document_segments(document_ids):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test successful duplicate document indexing flow."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# Dataset via query.first()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# scalars() call sequence:
# 1) documents list
# 2..N) segments per document
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
# First call returns documents; subsequent calls return segments
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
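The same first-call/later-calls shape reappears in the tests below, each time tracking state through a function attribute. A possible factory (a sketch only, not applied in this change) would express the intent once:

from unittest.mock import MagicMock

def scalars_returning(first, rest):
    """Build a scalars() side_effect: first call yields `first`, later calls yield `rest`."""
    state = {"calls": 0}

    def _side_effect(*args, **kwargs):
        result = MagicMock()
        result.all.return_value = first if state["calls"] == 0 else rest
        state["calls"] += 1
        return result

    return _side_effect

# Hypothetical usage:
# mock_db_session.scalars.side_effect = scalars_returning(mock_documents, mock_document_segments)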
@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when billing limit is exceeded."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# First scalars() -> documents; subsequent -> empty segments
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.TEAM
@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when document is paused."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test that duplicate document indexing cleans old segments."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
# Act
@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
# Verify clean was called for each document
assert mock_processor.clean.call_count == len(mock_documents)
# Verify segments were deleted
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# ============================================================================

View File

@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100
# Mock database connection and engine
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock two batches of results, then empty
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
select_result3.__iter__.return_value = iter([])
# Configure side effects in the correct order
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
select_result1, # First SELECT
delete_result1, # First DELETE
select_result2, # Second SELECT
@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
assert result == 150
# Verify database calls
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
# Verify offload cleanup was called for both batches with file_ids
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
# Simplified verification - check that the right number of calls were made
# and that the SQL queries contain the expected patterns
actual_calls = mock_conn.execute.call_args_list
actual_calls = mock_session.execute.call_args_list
for i, actual_call in enumerate(actual_calls):
sql_text = str(actual_call[0][0])
normalized = " ".join(sql_text.split())
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
# Verify it's a SELECT query that now includes file_id
sql_text = str(actual_call[0][0])
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
assert "WHERE app_id = :app_id" in sql_text
assert "LIMIT :batch_size" in sql_text
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
assert "WHERE app_id = :app_id" in normalized
assert "LIMIT :batch_size" in normalized
else: # DELETE calls (odd indices: 1, 3)
# Verify it's a DELETE query
sql_text = str(actual_call[0][0])
assert "DELETE FROM workflow_draft_variables" in sql_text
assert "WHERE id IN :ids" in sql_text
assert "DELETE FROM workflow_draft_variables" in normalized
assert "WHERE id IN :ids" in normalized
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000
# Mock database connection
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.return_value = empty_result
mock_session.execute.return_value = empty_result
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 0
assert mock_conn.execute.call_count == 1 # Only one select query
assert mock_session.execute.call_count == 1 # Only one select query
mock_offload_cleanup.assert_not_called() # No files to clean up
def test_delete_draft_variables_batch_invalid_batch_size(self):
@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50
# Mock database
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock one batch then empty
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
# Select query result
select_result,
# Delete query result
@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:
# Verify offload cleanup was called with file_ids
if batch_file_ids:
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
# Verify logging calls
assert mock_logging.info.call_count == 2
@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
actual_calls = mock_conn.execute.call_args_list
# First call should be the SELECT query
select_call_sql = str(actual_calls[0][0][0])
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
# Second call should be DELETE upload_files
delete_upload_call_sql = str(actual_calls[1][0][0])
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
assert "DELETE FROM upload_files" in delete_upload_call_sql
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
# Third call should be DELETE workflow_draft_variable_files
delete_variable_files_call_sql = str(actual_calls[2][0][0])
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql