mirror of https://github.com/langgenius/dify.git

refactor: use session factory instead of calling db.session directly (#31198)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent 071bbc6d74
commit 121d301a41
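The pattern applied across every file in this diff: instead of reading and writing through the request-scoped db.session and closing it by hand, each Celery task opens a short-lived session from core.db.session_factory and scopes all queries to it. A minimal sketch of the shape, assuming the project's Dataset model; the task body here is illustrative, not taken from any one file below, and commits stay explicit just as in the real tasks:

from core.db.session_factory import session_factory
from models.dataset import Dataset


def load_dataset(dataset_id: str):
    # Old style (removed in this commit):
    #   dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
    #   ...
    #   db.session.close()
    # New style: the context manager owns the session's lifetime, so there is
    # no manual db.session.close() in the task body.
    with session_factory.create_session() as session:
        return session.query(Dataset).where(Dataset.id == dataset_id).first()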
@ -3,8 +3,8 @@ from datetime import UTC, datetime
 from typing import Any, ClassVar
 
 from pydantic import TypeAdapter
-from sqlalchemy.orm import Session, sessionmaker
 
+from core.db.session_factory import session_factory
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events.base import GraphEngineEvent
 from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
         cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
         start_time: datetime,
         trigger_log_id: str,
-        session_maker: sessionmaker[Session],
     ):
         super().__init__()
         self.trigger_log_id = trigger_log_id
         self.start_time = start_time
         self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
-        self.session_maker = session_maker
 
     def on_graph_start(self):
         pass
@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
         Update trigger log with success or failure.
         """
         if isinstance(event, tuple(self._STATUS_MAP.keys())):
-            with self.session_maker() as session:
+            with session_factory.create_session() as session:
                 repo = SQLAlchemyWorkflowTriggerLogRepository(session)
                 trigger_log = repo.get_by_id(self.trigger_log_id)
                 if not trigger_log:
@ -35,7 +35,6 @@ from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog
-from repositories.factory import DifyAPIRepositoryFactory
 from tasks.ops_trace_task import process_trace_tasks
 
 if TYPE_CHECKING:
@ -473,6 +472,9 @@ class TraceTask:
         if cls._workflow_run_repo is None:
             with cls._repo_lock:
                 if cls._workflow_run_repo is None:
+                    # Lazy import to avoid circular import during module initialization
+                    from repositories.factory import DifyAPIRepositoryFactory
+
                     session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
                     cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
         return cls._workflow_run_repo
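The hunk above builds the workflow-run repository lazily behind a double-checked lock, so the sessionmaker is constructed at most once per process even when several threads race to enqueue trace tasks. A self-contained sketch of that pattern; the class and attribute names here are illustrative, not the ones in the codebase:

import threading


class LazyRepo:
    _repo = None                  # cached instance, shared per process
    _lock = threading.Lock()      # guards first-time construction only

    @classmethod
    def get(cls):
        # Fast path: once built, no lock is taken.
        if cls._repo is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have created
                # the instance between the first check and acquiring the lock.
                if cls._repo is None:
                    cls._repo = object()  # stand-in for the real repository
        return cls._repo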
@ -4,11 +4,11 @@ import time
 import click
 from celery import shared_task
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import AttachmentDocument, ChildDocument, Document
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
 from models.dataset import DatasetAutoDisableLog, DocumentSegment
@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str):
|
|||
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if dataset_document.indexing_status != "completed":
|
||||
db.session.close()
|
||||
return
|
||||
if dataset_document.indexing_status != "completed":
|
||||
return
|
||||
|
||||
indexing_cache_key = f"document_{dataset_document.id}_indexing"
|
||||
indexing_cache_key = f"document_{dataset_document.id}_indexing"
|
||||
|
||||
try:
|
||||
dataset = dataset_document.dataset
|
||||
if not dataset:
|
||||
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
|
||||
try:
|
||||
dataset = dataset_document.dataset
|
||||
if not dataset:
|
||||
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
|
||||
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||
|
||||
# delete auto disable log
|
||||
session.query(DatasetAutoDisableLog).where(
|
||||
DatasetAutoDisableLog.document_id == dataset_document.id
|
||||
).delete()
|
||||
|
||||
# update segment to enable
|
||||
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
|
||||
{
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.disabled_at: None,
|
||||
DocumentSegment.disabled_by: None,
|
||||
DocumentSegment.updated_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
session.commit()
|
||||
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||
|
||||
# delete auto disable log
|
||||
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
|
||||
|
||||
# update segment to enable
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
|
||||
{
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.disabled_at: None,
|
||||
DocumentSegment.disabled_by: None,
|
||||
DocumentSegment.updated_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("add document to index failed")
|
||||
dataset_document.enabled = False
|
||||
dataset_document.disabled_at = naive_utc_now()
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("add document to index failed")
|
||||
dataset_document.enabled = False
|
||||
dataset_document.disabled_at = naive_utc_now()
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import click
 from celery import shared_task
 from werkzeug.exceptions import NotFound
 
+from core.db.session_factory import session_factory
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
|
|||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
active_jobs_key = f"annotation_import_active:{tenant_id}"
|
||||
|
||||
# get app info
|
||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
with session_factory.create_session() as session:
|
||||
# get app info
|
||||
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
|
||||
if app:
|
||||
try:
|
||||
documents = []
|
||||
for content in content_list:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
|
||||
if app:
|
||||
try:
|
||||
documents = []
|
||||
for content in content_list:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
|
||||
)
|
||||
session.add(annotation)
|
||||
session.flush()
|
||||
|
||||
document = Document(
|
||||
page_content=content["question"],
|
||||
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
# if annotation reply is enabled , batch add annotations' index
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.flush()
|
||||
|
||||
document = Document(
|
||||
page_content=content["question"],
|
||||
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
# if annotation reply is enabled , batch add annotations' index
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if app_annotation_setting:
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
app_annotation_setting.collection_binding_id, "annotation"
|
||||
)
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
raise NotFound("App annotation setting not found")
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
|
||||
if app_annotation_setting:
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
app_annotation_setting.collection_binding_id, "annotation"
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.create(documents, duplicate_check=True)
|
||||
|
||||
session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Build index successful for batch import annotation: {} latency: {}".format(
|
||||
job_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
raise NotFound("App annotation setting not found")
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=dataset_collection_binding.provider_name,
|
||||
embedding_model=dataset_collection_binding.model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.create(documents, duplicate_check=True)
|
||||
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Build index successful for batch import annotation: {} latency: {}".format(
|
||||
job_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
|
||||
redis_client.setex(indexing_error_msg_key, 600, str(e))
|
||||
logger.exception("Build index for batch import annotations failed")
|
||||
finally:
|
||||
# Clean up active job tracking to release concurrency slot
|
||||
try:
|
||||
redis_client.zrem(active_jobs_key, job_id)
|
||||
logger.debug("Released concurrency slot for job: %s", job_id)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail if cleanup fails - the job will be auto-expired
|
||||
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
|
||||
|
||||
# Close database session
|
||||
db.session.close()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
|
||||
redis_client.setex(indexing_error_msg_key, 600, str(e))
|
||||
logger.exception("Build index for batch import annotations failed")
|
||||
finally:
|
||||
# Clean up active job tracking to release concurrency slot
|
||||
try:
|
||||
redis_client.zrem(active_jobs_key, job_id)
|
||||
logger.debug("Released concurrency slot for job: %s", job_id)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail if cleanup fails - the job will be auto-expired
|
||||
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import click
 from celery import shared_task
 from sqlalchemy import exists, select
 
+from core.db.session_factory import session_factory
 from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
|
|||
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
return
|
||||
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
|
||||
if not app_annotation_setting:
|
||||
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
||||
|
||||
try:
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
collection_binding_id=app_annotation_setting.collection_binding_id,
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
|
||||
return
|
||||
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
|
||||
|
||||
try:
|
||||
if annotations_exists:
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.delete()
|
||||
except Exception:
|
||||
logger.exception("Delete annotation index failed when annotation deleted.")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
collection_binding_id=app_annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
||||
# delete annotation setting
|
||||
db.session.delete(app_annotation_setting)
|
||||
db.session.commit()
|
||||
try:
|
||||
if annotations_exists:
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
vector.delete()
|
||||
except Exception:
|
||||
logger.exception("Delete annotation index failed when annotation deleted.")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("Annotation batch deleted index failed")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, "error")
|
||||
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
|
||||
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
|
||||
finally:
|
||||
redis_client.delete(disable_app_annotation_key)
|
||||
db.session.close()
|
||||
# delete annotation setting
|
||||
session.delete(app_annotation_setting)
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Annotation batch deleted index failed")
|
||||
redis_client.setex(disable_app_annotation_job_key, 600, "error")
|
||||
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
|
||||
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
|
||||
finally:
|
||||
redis_client.delete(disable_app_annotation_key)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import click
 from celery import shared_task
 from sqlalchemy import select
 
+from core.db.session_factory import session_factory
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
 from models.dataset import Dataset
@ -33,92 +33,98 @@ def enable_annotation_reply_task(
|
|||
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
# get app info
|
||||
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
with session_factory.create_session() as session:
|
||||
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
|
||||
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
if not app:
|
||||
logger.info(click.style(f"App not found: {app_id}", fg="red"))
|
||||
return
|
||||
|
||||
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
|
||||
try:
|
||||
documents = []
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
)
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
if annotation_setting:
|
||||
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
|
||||
old_dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
annotation_setting.collection_binding_id, "annotation"
|
||||
)
|
||||
)
|
||||
if old_dataset_collection_binding and annotations:
|
||||
old_dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=old_dataset_collection_binding.provider_name,
|
||||
embedding_model=old_dataset_collection_binding.model_name,
|
||||
collection_binding_id=old_dataset_collection_binding.id,
|
||||
)
|
||||
|
||||
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
try:
|
||||
old_vector.delete()
|
||||
except Exception as e:
|
||||
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
|
||||
annotation_setting.score_threshold = score_threshold
|
||||
annotation_setting.collection_binding_id = dataset_collection_binding.id
|
||||
annotation_setting.updated_user_id = user_id
|
||||
annotation_setting.updated_at = naive_utc_now()
|
||||
db.session.add(annotation_setting)
|
||||
else:
|
||||
new_app_annotation_setting = AppAnnotationSetting(
|
||||
app_id=app_id,
|
||||
score_threshold=score_threshold,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
created_user_id=user_id,
|
||||
updated_user_id=user_id,
|
||||
try:
|
||||
documents = []
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_provider_name, embedding_model_name, "annotation"
|
||||
)
|
||||
db.session.add(new_app_annotation_setting)
|
||||
annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
|
||||
old_dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
|
||||
annotation_setting.collection_binding_id, "annotation"
|
||||
)
|
||||
)
|
||||
if old_dataset_collection_binding and annotations:
|
||||
old_dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=old_dataset_collection_binding.provider_name,
|
||||
embedding_model=old_dataset_collection_binding.model_name,
|
||||
collection_binding_id=old_dataset_collection_binding.id,
|
||||
)
|
||||
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question_text,
|
||||
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
|
||||
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
try:
|
||||
old_vector.delete()
|
||||
except Exception as e:
|
||||
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
|
||||
annotation_setting.score_threshold = score_threshold
|
||||
annotation_setting.collection_binding_id = dataset_collection_binding.id
|
||||
annotation_setting.updated_user_id = user_id
|
||||
annotation_setting.updated_at = naive_utc_now()
|
||||
session.add(annotation_setting)
|
||||
else:
|
||||
new_app_annotation_setting = AppAnnotationSetting(
|
||||
app_id=app_id,
|
||||
score_threshold=score_threshold,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
created_user_id=user_id,
|
||||
updated_user_id=user_id,
|
||||
)
|
||||
documents.append(document)
|
||||
session.add(new_app_annotation_setting)
|
||||
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
try:
|
||||
vector.delete_by_metadata_field("app_id", app_id)
|
||||
except Exception as e:
|
||||
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
|
||||
vector.create(documents)
|
||||
db.session.commit()
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("Annotation batch created index failed")
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, "error")
|
||||
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
|
||||
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
|
||||
db.session.rollback()
|
||||
finally:
|
||||
redis_client.delete(enable_app_annotation_key)
|
||||
db.session.close()
|
||||
dataset = Dataset(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider=embedding_provider_name,
|
||||
embedding_model=embedding_model_name,
|
||||
collection_binding_id=dataset_collection_binding.id,
|
||||
)
|
||||
if annotations:
|
||||
for annotation in annotations:
|
||||
document = Document(
|
||||
page_content=annotation.question_text,
|
||||
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
|
||||
)
|
||||
documents.append(document)
|
||||
|
||||
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
|
||||
try:
|
||||
vector.delete_by_metadata_field("app_id", app_id)
|
||||
except Exception as e:
|
||||
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
|
||||
vector.create(documents)
|
||||
session.commit()
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"App annotations added to index: {app_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Annotation batch created index failed")
|
||||
redis_client.setex(enable_app_annotation_job_key, 600, "error")
|
||||
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
|
||||
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
|
||||
session.rollback()
|
||||
finally:
|
||||
redis_client.delete(enable_app_annotation_key)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ from typing import Any
 
 from celery import shared_task
 from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import Session
 
 from configs import dify_config
 from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.layers.trigger_post_layer import TriggerPostLayer
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
 from models.account import Account
 from models.enums import CreatorUserRole, WorkflowTriggerStatus
 from models.model import App, EndUser, Tenant
@ -98,10 +98,7 @@ def _execute_workflow_common(
 ):
     """Execute workflow with common logic and trigger log updates."""
 
-    # Create a new session for this task
-    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
-    with session_factory() as session:
+    with session_factory.create_session() as session:
         trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
 
         # Get trigger log
@ -157,7 +154,7 @@ def _execute_workflow_common(
             root_node_id=trigger_data.root_node_id,
             graph_engine_layers=[
                 # TODO: Re-enable TimeSliceLayer after the HITL release.
-                TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
+                TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
             ],
         )
 
@ -3,11 +3,11 @@ import time
 
 import click
 from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
 from models.model import UploadFile
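Alongside the session change, this task (like the clean_dataset and clean_document tasks further down) switches from loading rows and deleting them one by one to issuing bulk DELETE statements with sqlalchemy.delete, which is why delete joins select in the import line above. A minimal sketch of that pattern, assuming the project's UploadFile model and an already-open session; the helper name is illustrative:

from sqlalchemy import delete

from models.model import UploadFile


def remove_upload_files(session, file_ids: list[str]) -> None:
    # One DELETE ... WHERE id IN (...) statement instead of a session.delete()
    # call per row, mirroring the rewritten cleanup code below.
    if file_ids:
        session.execute(delete(UploadFile).where(UploadFile.id.in_(file_ids)))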
@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||
"""
|
||||
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
if not doc_form:
|
||||
raise ValueError("doc_form is required")
|
||||
|
||||
try:
|
||||
if not doc_form:
|
||||
raise ValueError("doc_form is required")
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
db.session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
for image_file in image_files:
|
||||
try:
|
||||
if image_file and image_file.key:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(stmt)
|
||||
session.delete(segment)
|
||||
if file_ids:
|
||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
try:
|
||||
if image_file and image_file.key:
|
||||
storage.delete(image_file.key)
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
upload_file_id,
|
||||
)
|
||||
db.session.delete(image_file)
|
||||
db.session.delete(segment)
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
|
||||
session.execute(stmt)
|
||||
|
||||
db.session.commit()
|
||||
if file_ids:
|
||||
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||
for file in files:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
|
||||
db.session.delete(file)
|
||||
session.commit()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned documents when documents deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
except Exception:
|
||||
logger.exception("Cleaned documents when documents deleted failed")
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ import pandas as pd
 from celery import shared_task
 from sqlalchemy import func
 
+from core.db.session_factory import session_factory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from libs import helper
@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
|
|||
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
try:
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not exist.")
|
||||
|
||||
dataset_document = db.session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if not dataset_document:
|
||||
raise ValueError("Document not exist.")
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
raise ValueError("Document is not available.")
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
raise ValueError("Document is not available.")
|
||||
|
||||
upload_file = db.session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
upload_file = session.get(UploadFile, upload_file_id)
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[segment["content"] for segment in content]
|
||||
)
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
session.add(dataset_document)
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[segment["content"] for segment in content]
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
db.session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
db.session.add(dataset_document)
|
||||
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
db.session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
finally:
|
||||
db.session.close()
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ import time
 
 import click
 from celery import shared_task
-from sqlalchemy import select
+from sqlalchemy import delete, select
 
+from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
-from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models import WorkflowType
 from models.dataset import (
@ -53,135 +53,155 @@ def clean_dataset_task(
|
|||
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = Dataset(
|
||||
id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique=indexing_technique,
|
||||
index_struct=index_struct,
|
||||
collection_binding_id=collection_binding_id,
|
||||
)
|
||||
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
|
||||
# Use JOIN to fetch attachments with bindings in a single query
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
|
||||
).all()
|
||||
|
||||
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
||||
# This ensures all invalid doc_form values are properly handled
|
||||
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
|
||||
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
logger.info(
|
||||
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
|
||||
)
|
||||
|
||||
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
|
||||
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
|
||||
# Continue with document and segment deletion even if vector cleanup fails
|
||||
logger.info(
|
||||
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
|
||||
dataset = Dataset(
|
||||
id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique=indexing_technique,
|
||||
index_struct=index_struct,
|
||||
collection_binding_id=collection_binding_id,
|
||||
)
|
||||
documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
|
||||
# Use JOIN to fetch attachments with bindings in a single query
|
||||
attachments_with_bindings = session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.tenant_id == tenant_id,
|
||||
SegmentAttachmentBinding.dataset_id == dataset_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
if documents is None or len(documents) == 0:
|
||||
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
|
||||
else:
|
||||
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
|
||||
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
|
||||
# This ensures all invalid doc_form values are properly handled
|
||||
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
|
||||
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
|
||||
for document in documents:
|
||||
db.session.delete(document)
|
||||
# delete document file
|
||||
doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if image_file is None:
|
||||
continue
|
||||
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
|
||||
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
|
||||
# Continue with document and segment deletion even if vector cleanup fails
|
||||
logger.info(
|
||||
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
|
||||
)
|
||||
|
||||
if documents is None or len(documents) == 0:
|
||||
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
|
||||
else:
|
||||
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
|
||||
|
||||
for document in documents:
|
||||
session.delete(document)
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)

segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(image_file.key)
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)

db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
db.session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
for document in documents:
try:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)

session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
file_ids = []
for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
continue
storage.delete(file.key)
db.session.delete(file)
except Exception:
continue
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
for file in files:
storage.delete(file.key)

db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(file_delete_stmt)

session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Failed to rollback database session")
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception:
logger.exception("Failed to rollback database session")

logger.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
# Explicitly close the session for test expectations and safety
try:
session.close()
except Exception:
logger.exception("Failed to close database session")
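The pattern this commit applies throughout the task modules is to replace the module-level db.session with a session that lives only inside a with-block obtained from session_factory.create_session(). A minimal sketch of that shape, using a hypothetical stand-in factory rather than Dify's actual core.db.session_factory:

from contextlib import contextmanager

from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker

# Hypothetical stand-in for core.db.session_factory; the real factory in the repo is not shown here.
engine = create_engine("sqlite:///:memory:")
_session_maker = sessionmaker(bind=engine, expire_on_commit=False)


@contextmanager
def create_session():
    # Yield a short-lived Session and always close it, mirroring the scoped usage in the refactored tasks.
    session = _session_maker()
    try:
        yield session
    finally:
        session.close()


# Usage mirroring the refactored tasks: no module-level db.session,
# the session only exists for the duration of the with-block.
with create_session() as session:
    result = session.execute(select(1)).scalar_one()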
|
|||
|
|
@@ -3,11 +3,11 @@ import time

import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
# Use JOIN to fetch attachments with bindings in a single query
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
|
||||
SegmentAttachmentBinding.dataset_id == dataset_id,
|
||||
SegmentAttachmentBinding.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
# Use JOIN to fetch attachments with bindings in a single query
|
||||
attachments_with_bindings = session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
|
||||
SegmentAttachmentBinding.dataset_id == dataset_id,
|
||||
SegmentAttachmentBinding.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
if image_file is None:
|
||||
continue
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
|
||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(image_file_delete_stmt)
|
||||
session.delete(segment)
|
||||
|
||||
session.commit()
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
session.delete(file)
|
||||
# delete segment attachments
|
||||
if attachments_with_bindings:
|
||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||
for binding, attachment_file in attachments_with_bindings:
|
||||
try:
|
||||
storage.delete(attachment_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
upload_file_id,
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
binding.attachment_id,
|
||||
)
|
||||
db.session.delete(image_file)
|
||||
db.session.delete(segment)
|
||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
session.execute(attachment_file_delete_stmt)
|
||||
|
||||
db.session.commit()
|
||||
if file_id:
|
||||
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
db.session.delete(file)
|
||||
db.session.commit()
|
||||
# delete segment attachments
|
||||
if attachments_with_bindings:
|
||||
for binding, attachment_file in attachments_with_bindings:
|
||||
try:
|
||||
storage.delete(attachment_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
binding.attachment_id,
|
||||
)
|
||||
db.session.delete(attachment_file)
|
||||
db.session.delete(binding)
|
||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
|
||||
SegmentAttachmentBinding.id.in_(binding_ids)
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
# delete dataset metadata binding
|
||||
db.session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
db.session.commit()
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
|
|
|
|||
|
|
@@ -3,10 +3,10 @@ import time

import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

logger = logging.getLogger(__name__)
@@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()

try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)

index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]

for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
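Several of these hunks also replace per-row db.session.delete(...) loops with a single bulk delete() statement executed on the scoped session. A small, self-contained sketch of that idiom; the table below is a made-up stand-in, not Dify's DocumentSegment model:

from sqlalchemy import Integer, String, create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class SegmentStub(Base):
    # Illustrative stand-in for models.dataset.DocumentSegment
    __tablename__ = "segment_stub"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    content: Mapped[str] = mapped_column(String, default="")


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

segment_ids = [1, 2, 3]
with Session(engine) as session:
    # One DELETE ... WHERE id IN (...) statement instead of N session.delete() calls
    session.execute(delete(SegmentStub).where(SegmentStub.id.in_(segment_ids)))
    session.commit()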
|
||||
|
|
|
|||
|
|
@@ -4,9 +4,9 @@ import time
import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
|
|||
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "waiting":
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
try:
|
||||
# update segment status to indexing
|
||||
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
dataset = segment.dataset
|
||||
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
with session_factory.create_session() as session:
|
||||
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
if segment.status != "waiting":
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.load(dataset, [document])
|
||||
try:
|
||||
# update segment status to indexing
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "indexing",
|
||||
DocumentSegment.indexing_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
session.commit()
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
# update segment to completed
|
||||
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
dataset = segment.dataset
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("create segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_processor.load(dataset, [document])
|
||||
|
||||
# update segment to completed
|
||||
session.query(DocumentSegment).filter_by(id=segment.id).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("create segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
|
|
|
|||
|
|
@@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
|
|||
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "upgrade":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "upgrade":
|
||||
dataset_documents = (
|
||||
session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
# clean keywords
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Deal dataset vector index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
# clean keywords
|
||||
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = (
|
||||
session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
"Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Deal dataset vector index failed")
|
||||
|
|
|
|||
|
|
@@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
|||
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "remove":
|
||||
index_processor.clean(dataset, None, with_keywords=False)
|
||||
elif action == "add":
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if action == "remove":
|
||||
index_processor.clean(dataset, None, with_keywords=False)
|
||||
elif action == "add":
|
||||
dataset_documents = session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
if dataset_documents:
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
for dataset_document in dataset_documents:
|
||||
try:
|
||||
# add from vector index
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
if segments:
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Deal dataset vector index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, with_keywords=False)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
elif action == "update":
|
||||
dataset_documents = session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
# add new index
|
||||
if dataset_documents:
|
||||
# update document status
|
||||
dataset_documents_ids = [doc.id for doc in dataset_documents]
|
||||
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
|
||||
{"indexing_status": "indexing"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# clean index
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# update from vector index
|
||||
try:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
if segments:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(
|
||||
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
|
||||
)
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "completed"}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
|
||||
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
|
||||
)
|
||||
session.commit()
|
||||
else:
|
||||
# clean collection
|
||||
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Deal dataset vector index failed")
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@ import logging
from celery import shared_task

from configs import dify_config
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task

@@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)

@shared_task(queue="dataset")
def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
with session_factory.create_session() as session:
account = session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise

if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ import time
import click
from celery import shared_task

from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()

try:
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)

db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)

db.session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)

db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)

db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)

db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)

db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)

db.session.commit()

end_at = time.perf_counter()
logger.info(
click.style(
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
fg="green",
with session_factory.create_session() as session:
try:
session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
)

except Exception as e:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
db.session.rollback()
raise e
finally:
db.session.close()
session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)

session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)

session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)

session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)

session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)

session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)

session.commit()

end_at = time.perf_counter()
logger.info(
click.style(
(
f"Succeeded cleaning data from db for conversation_id {conversation_id} "
f"latency: {end_at - start_at}"
),
fg="green",
)
)

except Exception:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
session.rollback()
raise
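The conversation-cleanup hunk keeps commit and rollback explicit inside the scoped session rather than relying on db.session teardown. A sketch of that error-handling shape, with a placeholder statement standing in for the real bulk deletes:

import logging

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)
engine = create_engine("sqlite:///:memory:")


def cleanup_conversation(conversation_id: str) -> None:
    # Commit on success, roll back and re-raise on failure, all inside one scoped session;
    # the SELECT below is only a placeholder for the real deletes keyed by conversation_id.
    with Session(engine) as session:
        try:
            session.execute(text("SELECT :cid"), {"cid": conversation_id})
            session.commit()
        except Exception:
            logger.exception("cleanup failed for conversation_id: %s", conversation_id)
            session.rollback()
            raise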
|
||||
|
|
|
|||
|
|
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@ -26,49 +26,52 @@ def delete_segment_from_index_task(
|
|||
"""
|
||||
logger.info(click.style("Start delete segment from index", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if not dataset_document:
|
||||
return
|
||||
dataset_document = session.query(Document).where(Document.id == document_id).first()
|
||||
if not dataset_document:
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logging.info("Document not in valid state for index operations, skipping")
|
||||
return
|
||||
doc_form = dataset_document.doc_form
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
logging.info("Document not in valid state for index operations, skipping")
|
||||
return
|
||||
doc_form = dataset_document.doc_form
|
||||
|
||||
# Proceed with index cleanup using the index_node_ids directly
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset,
|
||||
index_node_ids,
|
||||
with_keywords=True,
|
||||
delete_child_chunks=True,
|
||||
precomputed_child_node_ids=child_node_ids,
|
||||
)
|
||||
if dataset.is_multimodal:
|
||||
# delete segment attachment binding
|
||||
segment_attachment_bindings = (
|
||||
db.session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
# Proceed with index cleanup using the index_node_ids directly
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
index_processor.clean(
|
||||
dataset,
|
||||
index_node_ids,
|
||||
with_keywords=True,
|
||||
delete_child_chunks=True,
|
||||
precomputed_child_node_ids=child_node_ids,
|
||||
)
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
|
||||
for binding in segment_attachment_bindings:
|
||||
db.session.delete(binding)
|
||||
# delete upload file
|
||||
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
|
||||
db.session.commit()
|
||||
if dataset.is_multimodal:
|
||||
# delete segment attachment binding
|
||||
segment_attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
)
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
|
||||
for binding in segment_attachment_bindings:
|
||||
session.delete(binding)
|
||||
# delete upload file
|
||||
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("delete segment from index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("delete segment from index failed")
|
||||
|
|
|
|||
|
|
@@ -4,8 +4,8 @@ import time
import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()

segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return

if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
db.session.close()
return

indexing_cache_key = f"segment_{segment.id}_indexing"

try:
dataset = segment.dataset

if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return

dataset_document = segment.document

if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return

if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"

index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
try:
dataset = segment.dataset

end_at = time.perf_counter()
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return

dataset_document = segment.document

if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return

if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return

index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])

end_at = time.perf_counter()
logger.info(
click.style(
f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
session.commit()
finally:
redis_client.delete(indexing_cache_key)

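The hunks above replace the module-level `db.session` (plus explicit `close()` calls in `finally`) with a scoped session obtained from `session_factory.create_session()` used as a context manager. The real `core.db.session_factory` is not shown in this diff, so the stand-in below is an assumption; it only mirrors the usage pattern the refactored tasks rely on.

# Hypothetical stand-in for core.db.session_factory, for illustration only;
# the real Dify implementation may differ.
from contextlib import contextmanager

import sqlalchemy as sa
from sqlalchemy.orm import Session, sessionmaker


class SessionFactory:
    def __init__(self, url: str) -> None:
        self._engine = sa.create_engine(url)
        self._maker = sessionmaker(bind=self._engine, expire_on_commit=False)

    @contextmanager
    def create_session(self):
        # Closing in `finally` is what lets the refactored tasks drop their
        # explicit `db.session.close()` calls.
        session: Session = self._maker()
        try:
            yield session
        finally:
            session.close()


session_factory = SessionFactory("sqlite:///:memory:")

with session_factory.create_session() as session:
    # Query and commit inside the block, exactly as the refactored tasks do.
    print(session.execute(sa.text("select 1")).scalar())
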
@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument

@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
|
|||
"""
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments:
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if dataset.is_multimodal:
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_attachment_bindings = (
|
||||
db.session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_node_ids.extend(attachment_ids)
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||
).all()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
# update segment error msg
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
if not segments:
|
||||
return
|
||||
|
||||
try:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if dataset.is_multimodal:
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
|
||||
.all()
|
||||
)
|
||||
if segment_attachment_bindings:
|
||||
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
|
||||
index_node_ids.extend(attachment_ids)
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
# update segment error msg
|
||||
session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"disabled_at": None,
|
||||
"disabled_by": None,
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ import time

import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.datasource_provider_service import DatasourceProviderService

@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
|||
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
document.tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
document.indexing_status = "error"
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
document.tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
document.indexing_status = "error"
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.commit()
|
||||
return
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ import click
from celery import shared_task

from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService

@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
|
||||
db.session.close()
|
||||
return
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
|
||||
return
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
except Exception as e:
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
return
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ import time

import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db

@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
|||
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
db.session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence

import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select

from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService

@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
|
|||
|
||||
|
||||
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
|
||||
documents = []
|
||||
documents: list[Document] = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
current = int(getattr(vector_space, "size", 0) or 0)
|
||||
limit = int(getattr(vector_space, "limit", 0) or 0)
|
||||
if limit > 0 and (current + count) > limit:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have exceeded the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
current = int(getattr(vector_space, "size", 0) or 0)
|
||||
limit = int(getattr(vector_space, "limit", 0) or 0)
|
||||
if limit > 0 and (current + count) > limit:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have exceeded the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
documents = list(
|
||||
session.scalars(
|
||||
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
).all()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
documents = list(
|
||||
session.scalars(
|
||||
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
).all()
|
||||
)
|
||||
|
||||
if document:
|
||||
for document in documents:
|
||||
logger.info(click.style(f"Start process document: {document.id}", fg="green"))
|
||||
|
||||
# clean old data
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
|
@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
|||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(list(documents))
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
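The `_duplicate_document_indexing_task` hunks above keep the vector-space quota check `if limit > 0 and (current + count) > limit`. A short worked example of that arithmetic: with a limit of 200, 195 documents already counted, and a batch of 10 uploads, 195 + 10 > 200, so the batch is rejected; a limit of 0 means no quota is enforced.

def exceeds_quota(current: int, count: int, limit: int) -> bool:
    # limit <= 0 is treated as "no limit", matching `if limit > 0 and ...`.
    return limit > 0 and (current + count) > limit


assert exceeds_quota(current=195, count=10, limit=200) is True
assert exceeds_quota(current=195, count=10, limit=0) is False
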

@ -4,11 +4,11 @@ import time
import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment

@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str):
|
|||
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
if segment.status != "completed":
|
||||
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
try:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
dataset = segment.dataset
|
||||
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
with session_factory.create_session() as session:
|
||||
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
|
||||
if not segment:
|
||||
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
if segment.status != "completed":
|
||||
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
|
||||
return
|
||||
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
try:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
dataset = segment.dataset
|
||||
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = segment.document
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
if (
|
||||
not dataset_document.enabled
|
||||
or dataset_document.archived
|
||||
or dataset_document.indexing_status != "completed"
|
||||
):
|
||||
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
multimodel_documents = []
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodel_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
multimodel_documents = []
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodel_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# save vector index
|
||||
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
|
||||
# save vector index
|
||||
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segment to index failed")
|
||||
segment.enabled = False
|
||||
segment.disabled_at = naive_utc_now()
|
||||
segment.status = "error"
|
||||
segment.error = str(e)
|
||||
session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment

@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
|
|||
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
|
||||
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
|
||||
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
if not dataset_document:
|
||||
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
|
||||
return
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
|
||||
return
|
||||
# sync index processor
|
||||
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
try:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
if not segments:
|
||||
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
|
||||
return
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
try:
|
||||
documents = []
|
||||
multimodal_documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
page_content=segment.content,
|
||||
metadata={
|
||||
"doc_id": segment.index_node_id,
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
child_chunks = segment.get_child_chunks()
|
||||
if child_chunks:
|
||||
child_documents = []
|
||||
for child_chunk in child_chunks:
|
||||
child_document = ChildDocument(
|
||||
page_content=child_chunk.content,
|
||||
metadata={
|
||||
"doc_id": child_chunk.index_node_id,
|
||||
"doc_hash": child_chunk.index_node_hash,
|
||||
"document_id": document_id,
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
child_documents.append(child_document)
|
||||
document.children = child_documents
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||
|
||||
if dataset.is_multimodal:
|
||||
for attachment in segment.attachments:
|
||||
multimodal_documents.append(
|
||||
AttachmentDocument(
|
||||
page_content=attachment["name"],
|
||||
metadata={
|
||||
"doc_id": attachment["id"],
|
||||
"doc_hash": "",
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
"doc_type": DocType.IMAGE,
|
||||
},
|
||||
)
|
||||
)
|
||||
documents.append(document)
|
||||
# save vector index
|
||||
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segments to index failed")
|
||||
# update segment error msg
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"error": str(e),
|
||||
"status": "error",
|
||||
"disabled_at": naive_utc_now(),
|
||||
"enabled": False,
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception("enable segments to index failed")
|
||||
# update segment error msg
|
||||
session.query(DocumentSegment).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
).update(
|
||||
{
|
||||
"error": str(e),
|
||||
"status": "error",
|
||||
"disabled_at": naive_utc_now(),
|
||||
"enabled": False,
|
||||
}
|
||||
)
|
||||
session.commit()
|
||||
finally:
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
redis_client.delete(indexing_cache_key)
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import time
import click
from celery import shared_task

from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document

logger = logging.getLogger(__name__)

@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()

document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return

try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)

@ -1,14 +1,17 @@
import logging
import time
from collections.abc import Callable
from typing import Any, cast

import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

from core.db.session_factory import session_factory
from extensions.ext_database import db
from models import (
ApiToken,

@ -77,7 +80,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)

end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:

@ -89,8 +91,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):

def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)

_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",

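In the hunks above and below, each `del_*` callback gains a `session` parameter so `_delete_records` can pass in one scoped session instead of the callbacks reaching for the global `db.session`. The body of `_delete_records` is not shown in this diff, so the loop below is an assumption sketched only to show why the new signature is `(session, record_id)`.

# Hypothetical reconstruction of the batching pattern; names and the loop
# structure are assumptions, not the actual Dify helper.
from collections.abc import Callable

import sqlalchemy as sa
from sqlalchemy.orm import Session


def _delete_records_sketch(
    session: Session,
    select_sql: str,
    params: dict,
    delete_one: Callable[[Session, str], None],
) -> None:
    while True:
        rows = session.execute(sa.text(select_sql), params).fetchall()
        if not rows:
            break
        for (record_id,) in rows:
            # The callback reuses the caller's session instead of the old
            # global db.session, so one session spans the whole batch.
            delete_one(session, str(record_id))
        session.commit()
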
@ -101,8 +103,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_site(tenant_id: str, app_id: str):
|
||||
def del_site(site_id: str):
|
||||
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
|
||||
def del_site(session, site_id: str):
|
||||
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from sites where app_id=:app_id limit 1000""",
|
||||
|
|
@ -113,8 +115,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||
def del_mcp_server(mcp_server_id: str):
|
||||
db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
||||
def del_mcp_server(session, mcp_server_id: str):
|
||||
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
|
||||
|
|
@ -125,8 +127,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||
def del_api_token(api_token_id: str):
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
|
||||
def del_api_token(session, api_token_id: str):
|
||||
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from api_tokens where app_id=:app_id limit 1000""",
|
||||
|
|
@ -137,8 +139,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_installed_apps(tenant_id: str, app_id: str):
|
||||
def del_installed_app(installed_app_id: str):
|
||||
db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
|
||||
def del_installed_app(session, installed_app_id: str):
|
||||
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -149,10 +151,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_recommended_apps(tenant_id: str, app_id: str):
|
||||
def del_recommended_app(recommended_app_id: str):
|
||||
db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
def del_recommended_app(session, recommended_app_id: str):
|
||||
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from recommended_apps where app_id=:app_id limit 1000""",
|
||||
|
|
@ -163,8 +163,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||
def del_annotation_hit_history(annotation_hit_history_id: str):
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
|
||||
def del_annotation_hit_history(session, annotation_hit_history_id: str):
|
||||
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
|
@ -175,8 +175,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
|||
"annotation hit history",
|
||||
)
|
||||
|
||||
def del_annotation_setting(annotation_setting_id: str):
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
|
||||
def del_annotation_setting(session, annotation_setting_id: str):
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
|
@ -189,8 +189,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
||||
def del_dataset_join(dataset_join_id: str):
|
||||
db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
|
||||
def del_dataset_join(session, dataset_join_id: str):
|
||||
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
|
||||
|
|
@ -201,8 +201,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_workflows(tenant_id: str, app_id: str):
|
||||
def del_workflow(workflow_id: str):
|
||||
db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
|
||||
def del_workflow(session, workflow_id: str):
|
||||
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -241,10 +241,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_app_log(workflow_app_log_id: str):
|
||||
db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
def del_workflow_app_log(session, workflow_app_log_id: str):
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -255,11 +253,11 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_conversations(tenant_id: str, app_id: str):
|
||||
def del_conversation(conversation_id: str):
|
||||
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
def del_conversation(session, conversation_id: str):
|
||||
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
|
||||
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from conversations where app_id=:app_id limit 1000""",
|
||||
|
|
@ -270,28 +268,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_conversation_variables(*, app_id: str):
|
||||
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
|
||||
with db.engine.connect() as conn:
|
||||
conn.execute(stmt)
|
||||
conn.commit()
|
||||
with session_factory.create_session() as session:
|
||||
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
|
||||
|
||||
|
||||
def _delete_app_messages(tenant_id: str, app_id: str):
|
||||
def del_message(message_id: str):
|
||||
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
|
||||
def del_message(session, message_id: str):
|
||||
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
|
||||
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
|
||||
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
|
||||
db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
|
||||
db.session.query(Message).where(Message.id == message_id).delete()
|
||||
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(Message).where(Message.id == message_id).delete()
|
||||
|
||||
_delete_records(
|
||||
"""select id from messages where app_id=:app_id limit 1000""",
|
||||
|
|
@ -302,8 +298,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
||||
def del_tool_provider(tool_provider_id: str):
|
||||
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
|
||||
def del_tool_provider(session, tool_provider_id: str):
|
||||
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
|
@ -316,8 +312,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
||||
def del_tag_binding(tag_binding_id: str):
|
||||
db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
|
||||
def del_tag_binding(session, tag_binding_id: str):
|
||||
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
|
||||
|
|
@ -328,8 +324,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_end_users(tenant_id: str, app_id: str):
|
||||
def del_end_user(end_user_id: str):
|
||||
db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
|
||||
def del_end_user(session, end_user_id: str):
|
||||
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -340,10 +336,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
||||
def del_trace_app_config(trace_app_config_id: str):
|
||||
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
def del_trace_app_config(session, trace_app_config_id: str):
|
||||
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from trace_app_config where app_id=:app_id limit 1000""",
|
||||
|
|
@ -381,14 +375,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
|||
total_files_deleted = 0
|
||||
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
with session_factory.create_session() as session:
|
||||
# Get a batch of draft variable IDs along with their file_ids
|
||||
query_sql = """
|
||||
SELECT id, file_id FROM workflow_draft_variables
|
||||
WHERE app_id = :app_id
|
||||
LIMIT :batch_size
|
||||
"""
|
||||
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
|
||||
result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
|
||||
|
||||
rows = list(result)
|
||||
if not rows:
|
||||
|
|
@ -399,7 +393,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
|||
|
||||
# Clean up associated Offload data first
|
||||
if file_ids:
|
||||
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
|
||||
files_deleted = _delete_draft_variable_offload_data(session, file_ids)
|
||||
total_files_deleted += files_deleted
|
||||
|
||||
# Delete the draft variables
|
||||
|
|
@ -407,8 +401,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
|||
DELETE FROM workflow_draft_variables
|
||||
WHERE id IN :ids
|
||||
"""
|
||||
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
|
||||
batch_deleted = deleted_result.rowcount
|
||||
deleted_result = cast(
|
||||
CursorResult[Any],
|
||||
session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
|
||||
)
|
||||
batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
|
||||
total_deleted += batch_deleted
|
||||
|
||||
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
|
||||
|
|
@ -423,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
|||
return total_deleted
|
||||
|
||||
|
||||
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
||||
def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
|
||||
"""
|
||||
Delete Offload data associated with WorkflowDraftVariable file_ids.
|
||||
|
||||
|
|
@ -434,7 +431,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
|||
4. Deletes WorkflowDraftVariableFile records
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
session: Database session
|
||||
file_ids: List of WorkflowDraftVariableFile IDs
|
||||
|
||||
Returns:
|
||||
|
|
@ -450,12 +447,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
|||
try:
|
||||
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
|
||||
query_sql = """
|
||||
SELECT wdvf.id, uf.key, uf.id as upload_file_id
|
||||
FROM workflow_draft_variable_files wdvf
|
||||
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
|
||||
WHERE wdvf.id IN :file_ids
|
||||
"""
|
||||
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
|
||||
SELECT wdvf.id, uf.key, uf.id as upload_file_id
|
||||
FROM workflow_draft_variable_files wdvf
|
||||
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
|
||||
WHERE wdvf.id IN :file_ids \
|
||||
"""
|
||||
result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
|
||||
file_records = list(result)
|
||||
|
||||
# Delete from object storage and collect upload file IDs
|
||||
|
|
@ -473,17 +470,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
|||
# Delete UploadFile records
|
||||
if upload_file_ids:
|
||||
delete_upload_files_sql = """
|
||||
DELETE FROM upload_files
|
||||
WHERE id IN :upload_file_ids
|
||||
"""
|
||||
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
|
||||
DELETE \
|
||||
FROM upload_files
|
||||
WHERE id IN :upload_file_ids \
|
||||
"""
|
||||
session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
|
||||
|
||||
# Delete WorkflowDraftVariableFile records
|
||||
delete_variable_files_sql = """
|
||||
DELETE FROM workflow_draft_variable_files
|
||||
WHERE id IN :file_ids
|
||||
"""
|
||||
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
|
||||
DELETE \
|
||||
FROM workflow_draft_variable_files
|
||||
WHERE id IN :file_ids \
|
||||
"""
|
||||
session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
|
||||
|
||||
except Exception:
|
||||
logging.exception("Error deleting draft variable offload data:")
|
||||
|
|
@ -493,8 +492,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
|
|||
|
||||
|
||||
def _delete_app_triggers(tenant_id: str, app_id: str):
|
||||
def del_app_trigger(trigger_id: str):
|
||||
db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
||||
def del_app_trigger(session, trigger_id: str):
|
||||
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -505,8 +504,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
||||
def del_plugin_trigger(trigger_id: str):
|
||||
db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
||||
def del_plugin_trigger(session, trigger_id: str):
|
||||
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
|
@ -519,8 +518,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
||||
def del_webhook_trigger(trigger_id: str):
|
||||
db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
||||
def del_webhook_trigger(session, trigger_id: str):
|
||||
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
|
@ -533,10 +532,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
||||
def del_schedule_plan(plan_id: str):
|
||||
db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
def del_schedule_plan(session, plan_id: str):
|
||||
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -547,8 +544,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
|||
|
||||
|
||||
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
|
||||
def del_trigger_log(log_id: str):
|
||||
db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
||||
def del_trigger_log(session, log_id: str):
|
||||
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
|
@ -560,18 +557,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
|
|||
|
||||
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(query_sql), params)
|
||||
if rs.rowcount == 0:
|
||||
with session_factory.create_session() as session:
|
||||
rs = session.execute(sa.text(query_sql), params)
|
||||
rows = rs.fetchall()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for i in rs:
|
||||
for i in rows:
|
||||
record_id = str(i.id)
|
||||
try:
|
||||
delete_func(record_id)
|
||||
db.session.commit()
|
||||
delete_func(session, record_id)
|
||||
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Error occurred while deleting %s %s", name, record_id)
|
||||
continue
|
||||
# continue with next record even if one deletion fails
|
||||
session.rollback()
|
||||
break
|
||||
session.commit()
|
||||
|
||||
rs.close()
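Under the new contract, each `del_*` callback receives the batch's open session and leaves commit/rollback to `_delete_records` itself. A hypothetical caller following that contract (the params and name arguments are illustrative, since the calls above are truncated in this view):

from sqlalchemy.orm import Session

from models.model import Site  # model location assumed for illustration


def _delete_app_sites_example(tenant_id: str, app_id: str) -> None:
    def del_site(session: Session, site_id: str) -> None:
        # No commit here; _delete_records owns the transaction boundaries.
        session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)

    # Assumes _delete_records from this module is in scope.
    _delete_records(
        """select id from sites where app_id=:app_id limit 1000""",
        {"app_id": app_id},  # bind params inferred from the query text
        del_site,
        "site",
    )
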
@ -5,8 +5,8 @@ import click
|
|||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Document, DocumentSegment
|
||||
|
|
@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str):
|
|||
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).where(Document.id == document_id).first()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
if document.indexing_status != "completed":
|
||||
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
|
||||
try:
|
||||
dataset = document.dataset
|
||||
try:
|
||||
dataset = document.dataset
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if index_node_ids:
|
||||
try:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||
except Exception:
|
||||
logger.exception("clean dataset %s from index failed", dataset.id)
|
||||
# update segment to disable
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
|
||||
{
|
||||
DocumentSegment.enabled: False,
|
||||
DocumentSegment.disabled_at: naive_utc_now(),
|
||||
DocumentSegment.disabled_by: document.disabled_by,
|
||||
DocumentSegment.updated_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
if index_node_ids:
|
||||
try:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
|
||||
except Exception:
|
||||
logger.exception("clean dataset %s from index failed", dataset.id)
|
||||
# update segment to disable
|
||||
session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
|
||||
{
|
||||
DocumentSegment.enabled: False,
|
||||
DocumentSegment.disabled_at: naive_utc_now(),
|
||||
DocumentSegment.disabled_by: document.disabled_by,
|
||||
DocumentSegment.updated_at: naive_utc_now(),
|
||||
}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("remove document from index failed")
|
||||
if not document.archived:
|
||||
document.enabled = True
|
||||
db.session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
|
||||
db.session.close()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Document removed from index: {document.id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("remove document from index failed")
|
||||
if not document.archived:
|
||||
document.enabled = True
|
||||
session.commit()
|
||||
finally:
|
||||
redis_client.delete(indexing_cache_key)
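The whole task now runs inside one factory-managed session, which is why the explicit `db.session.close()` calls disappear. A condensed skeleton of the new structure, with the index cleanup elided:

from core.db.session_factory import session_factory
from models.dataset import Document


def remove_document_skeleton(document_id: str) -> None:
    # The context manager bounds the session for the whole task; no manual close.
    with session_factory.create_session() as session:
        document = session.query(Document).where(Document.id == document_id).first()
        if not document or document.indexing_status != "completed":
            return
        try:
            ...  # clean the vector index, disable segments, then session.commit()
        except Exception:
            ...  # optionally re-enable the document and session.commit()
        finally:
            ...  # always release the redis indexing lock
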
@ -3,11 +3,11 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant
|
||||
|
|
@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
|
|||
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
try:
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
user = db.session.query(Account).where(Account.id == user_id).first()
|
||||
if not user:
|
||||
logger.info(click.style(f"User not found: {user_id}", fg="red"))
|
||||
return
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
|
||||
if not tenant:
|
||||
raise ValueError("Tenant not found")
|
||||
user.current_tenant = tenant
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
user = session.query(Account).where(Account.id == user_id).first()
|
||||
if not user:
|
||||
logger.info(click.style(f"User not found: {user_id}", fg="red"))
|
||||
return
|
||||
tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
|
||||
if not tenant:
|
||||
raise ValueError("Tenant not found")
|
||||
user.current_tenant = tenant
|
||||
|
||||
for document_id in document_ids:
|
||||
retry_indexing_cache_key = f"document_{document_id}_is_retried"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
retry_indexing_cache_key = f"document_{document_id}_is_retried"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
document = (
|
||||
session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
||||
if dataset.runtime_mode == "rag_pipeline":
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.retry_error_document(dataset, document, user)
|
||||
else:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
|
||||
)
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
if dataset.runtime_mode == "rag_pipeline":
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.retry_error_document(dataset, document, user)
|
||||
else:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(retry_indexing_cache_key)
|
||||
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
raise e
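One behavioural detail in the hunk above: segments are now removed with a single bulk DELETE instead of per-row `session.delete()` calls. A minimal sketch of that step on its own, assuming the models imported in this file:

from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from models.dataset import DocumentSegment


def drop_segments_for_document(document_id: str) -> int:
    with session_factory.create_session() as session:
        segments = session.scalars(
            select(DocumentSegment).where(DocumentSegment.document_id == document_id)
        ).all()
        if not segments:
            return 0
        segment_ids = [segment.id for segment in segments]
        # One statement for the whole batch, matching the refactored task.
        session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
        session.commit()
        return len(segment_ids)
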
@ -3,11 +3,11 @@ import time
|
|||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
|
@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
|||
"""
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
raise ValueError("Dataset not found")
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
raise ValueError("Dataset not found")
|
||||
|
||||
sync_indexing_cache_key = f"document_{document_id}_is_sync"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
sync_indexing_cache_key = f"document_{document_id}_is_sync"
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
document = (
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
|
||||
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
if not document:
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
|
||||
return
|
||||
try:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(ex)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
|
||||
session.add(document)
|
||||
session.commit()
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
@ -16,6 +16,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
|
|
@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
|
|||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.ext_database import db
|
||||
from models.enums import (
|
||||
AppTriggerType,
|
||||
CreatorUserRole,
|
||||
|
|
@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
|
|||
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
|
||||
)
|
||||
trigger_entity: TriggerProviderEntity = provider_controller.entity
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
@ -7,9 +7,9 @@ from celery import shared_task
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_key
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
|
@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
|
|||
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
|
||||
try:
|
||||
now: int = _now_ts()
|
||||
with Session(db.engine) as session:
|
||||
with session_factory.create_session() as session:
|
||||
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
|
||||
|
||||
if not subscription:
@ -10,11 +10,10 @@ import logging
|
|||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
|
@ -46,10 +45,7 @@ def save_workflow_execution_task(
|
|||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
with session_factory.create_session() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowExecution.model_validate(execution_data)
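The per-call `sessionmaker(bind=db.engine, expire_on_commit=False)` is gone; the task borrows the process-wide factory instead. A condensed skeleton of the new shape, with the persistence details elided:

from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution


def save_execution_skeleton(execution_data: dict) -> bool:
    try:
        with session_factory.create_session() as session:
            execution = WorkflowExecution.model_validate(execution_data)
            ...  # map `execution` onto a WorkflowRun row and persist via `session`
            session.commit()  # assumed: the factory session does not autocommit
        return True
    except Exception:
        return False
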
@ -10,13 +10,12 @@ import logging
|
|||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from models import CreatorUserRole, WorkflowNodeExecutionModel
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
|
|||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Create a new session for this task
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
with session_factory.create_session() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowNodeExecution.model_validate(execution_data)
@ -1,15 +1,14 @@
|
|||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.nodes.trigger_schedule.exc import (
|
||||
ScheduleExecutionError,
|
||||
ScheduleNotFoundError,
|
||||
TenantOwnerNotFoundError,
|
||||
)
|
||||
from enums.quota_type import QuotaType, unlimited
|
||||
from extensions.ext_database import db
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError
|
||||
|
|
@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
|
|||
TenantOwnerNotFoundError: If no owner/admin for tenant
|
||||
ScheduleExecutionError: If workflow trigger fails
|
||||
"""
|
||||
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
with session_factory() as session:
|
||||
with session_factory.create_session() as session:
|
||||
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||
if not schedule:
|
||||
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")
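`Session.get()` is the primary-key lookup used above; it can be exercised the same way outside the task, assuming the model import from this file:

from core.db.session_factory import session_factory
from models.trigger import WorkflowSchedulePlan


def load_schedule(schedule_id: str) -> WorkflowSchedulePlan | None:
    with session_factory.create_session() as session:
        # Identity-map-aware lookup by primary key; returns None when missing.
        # Note: the returned instance is detached once the session closes.
        return session.get(WorkflowSchedulePlan, schedule_id)
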
@ -4,8 +4,8 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
|
|
@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele
|
|||
@pytest.fixture
|
||||
def app_and_tenant(flask_req_ctx):
|
||||
tenant_id = uuid.uuid4()
|
||||
tenant = Tenant(
|
||||
id=tenant_id,
|
||||
name="test_tenant",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
with session_factory.create_session() as session:
|
||||
tenant = Tenant(name="test_tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant_id, # Now tenant.id will have a value
|
||||
name=f"Test App for tenant {tenant.id}",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
yield (tenant, app)
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"Test App for tenant {tenant.id}",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
# Cleanup with proper error handling
|
||||
db.session.delete(app)
|
||||
db.session.delete(tenant)
|
||||
# return detached objects (ids will be used by tests)
|
||||
return (tenant, app)
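Because the fixture's session closes before the test body runs, the returned `tenant` and `app` are detached; tests should lift out plain ids early, roughly like this (hypothetical test, model import assumed):

from core.db.session_factory import session_factory
from models.workflow import WorkflowDraftVariable  # import path assumed


def test_fixture_objects_are_detached(app_and_tenant):
    tenant, app = app_and_tenant
    app_id = app.id  # copy the primitive out while the attribute is already loaded
    with session_factory.create_session() as session:
        count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
    assert count >= 0
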
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesIntegration:
|
||||
|
|
@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration:
|
|||
"""Create test data with apps and draft variables."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create a second app for testing
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.commit()
|
||||
|
||||
# Create draft variables for both apps
|
||||
variables_app1 = []
|
||||
variables_app2 = []
|
||||
|
||||
for i in range(5):
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
with session_factory.create_session() as session:
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(var1)
|
||||
variables_app1.append(var1)
|
||||
session.add(app2)
|
||||
session.flush()
|
||||
|
||||
var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var2)
|
||||
variables_app2.append(var2)
|
||||
variables_app1 = []
|
||||
variables_app2 = []
|
||||
for i in range(5):
|
||||
var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var1)
|
||||
variables_app1.append(var1)
|
||||
|
||||
# Commit all the variables to the database
|
||||
db.session.commit()
|
||||
var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var2)
|
||||
variables_app2.append(var2)
|
||||
session.commit()
|
||||
|
||||
app2_id = app2.id
|
||||
|
||||
yield {
|
||||
"app1": app,
|
||||
"app2": app2,
|
||||
"app2": App(id=app2_id), # dummy with id to avoid open session
|
||||
"tenant": tenant,
|
||||
"variables_app1": variables_app1,
|
||||
"variables_app2": variables_app2,
|
||||
}
|
||||
|
||||
# Cleanup - refresh session and check if objects still exist
|
||||
db.session.rollback() # Clear any pending changes
|
||||
|
||||
# Clean up remaining variables
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id.in_([app.id, app2_id]))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
# Clean up app2
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
|
||||
db.session.commit()
|
||||
session.execute(cleanup_query)
|
||||
app2_obj = session.get(App, app2_id)
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
|
||||
|
||||
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
|
||||
"""Test that batch deletion only removes variables for the specified app."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
app2_id = data["app2"].id
|
||||
|
||||
# Verify initial state
|
||||
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
assert app1_vars_before == 5
|
||||
assert app2_vars_before == 5
|
||||
|
||||
# Delete app1 variables
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
|
||||
assert app1_vars_after == 0 # All app1 variables deleted
|
||||
assert app2_vars_after == 5 # App2 variables unchanged
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
assert app1_vars_after == 0
|
||||
assert app2_vars_after == 5
|
||||
|
||||
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
|
||||
"""Test batch deletion with small batch size processes all records."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
|
||||
|
||||
assert deleted_count == 5
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
|
||||
"""Test that deleting variables for nonexistent app returns 0."""
|
||||
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
|
||||
|
||||
nonexistent_app_id = str(uuid.uuid4())
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
|
||||
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
|
||||
"""Test that _delete_draft_variables wrapper function works correctly."""
|
||||
data = setup_test_data
|
||||
app1_id = data["app1"].id
|
||||
|
||||
# Verify initial state
|
||||
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_before == 5
|
||||
|
||||
# Call wrapper function
|
||||
deleted_count = _delete_draft_variables(app1_id)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 5
|
||||
|
||||
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
assert vars_after == 0
|
||||
|
||||
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
|
||||
"""Test batch deletion with larger dataset to verify batching logic."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create many draft variables
|
||||
variables = []
|
||||
for i in range(25):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(var)
|
||||
variables.append(var)
|
||||
variable_ids = [i.id for i in variables]
|
||||
|
||||
# Commit the variables to the database
|
||||
db.session.commit()
|
||||
variable_ids: list[str] = []
|
||||
with session_factory.create_session() as session:
|
||||
variables = []
|
||||
for i in range(25):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
variables.append(var)
|
||||
session.commit()
|
||||
variable_ids = [v.id for v in variables]
|
||||
|
||||
try:
|
||||
# Use small batch size to force multiple batches
|
||||
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
|
||||
|
||||
assert deleted_count == 25
|
||||
|
||||
# Verify all variables are deleted
|
||||
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
assert remaining_vars == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
assert remaining == 0
|
||||
finally:
|
||||
query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.id.in_(variable_ids),
|
||||
with session_factory.create_session() as session:
|
||||
query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.id.in_(variable_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(query)
|
||||
session.execute(query)
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
"""Integration tests for draft variable deletion with Offload data."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with draft variables that have associated Offload files."""
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
# Create UploadFile records
|
||||
from core.variables.types import SegmentType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file1)
|
||||
db.session.add(upload_file2)
|
||||
db.session.flush()
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
session.add(upload_file1)
|
||||
session.add(upload_file2)
|
||||
session.flush()
|
||||
|
||||
# Create WorkflowDraftVariableFile records
|
||||
from core.variables.types import SegmentType
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
session.add(var_file1)
|
||||
session.add(var_file2)
|
||||
session.flush()
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
db.session.add(var_file1)
|
||||
db.session.add(var_file2)
|
||||
db.session.flush()
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(draft_var1)
|
||||
session.add(draft_var2)
|
||||
session.add(draft_var3)
|
||||
session.commit()
|
||||
|
||||
# Create WorkflowDraftVariable records with file associations
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
# Create a regular variable without Offload data
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
data = {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
db.session.add(draft_var1)
|
||||
db.session.add(draft_var2)
|
||||
db.session.add(draft_var3)
|
||||
db.session.commit()
|
||||
yield data
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
db.session.rollback()
|
||||
|
||||
# Clean up any remaining records
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
|
||||
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
db.session.execute(cleanup_query)
|
||||
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
session.rollback()
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
|
||||
(UploadFile, [uf.id for uf in data["upload_files"]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that deleting draft variables also cleans up associated Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to succeed
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = db.session.query(UploadFile).count()
|
||||
|
||||
assert draft_vars_before == 3 # 2 with files + 1 regular
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = session.query(UploadFile).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify results
|
||||
assert deleted_count == 3
|
||||
|
||||
# Check that all draft variables are deleted
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Check that associated Offload data is cleaned up
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
assert var_files_after == 0 # All variable files should be deleted
|
||||
assert upload_files_after == 0 # All upload files should be deleted
|
||||
|
||||
# Verify storage deletion was called for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
|
||||
assert "test/file1.json" in storage_keys_deleted
|
||||
|
|
@@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
|||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that database cleanup continues even when storage deletion fails."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Mock storage deletion to fail for first file, succeed for second
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
# Delete draft variables
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
|
||||
# Verify that all draft variables are still deleted
|
||||
assert deleted_count == 3
|
||||
|
||||
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert draft_vars_after == 0
|
||||
|
||||
# Database cleanup should still succeed even with storage errors
|
||||
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = db.session.query(UploadFile).count()
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage deletion was attempted for both files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
"""Test deletion with mix of variables with and without Offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Create additional app with only regular variables (no offload data)
|
||||
tenant = data["tenant"]
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app2)
|
||||
db.session.flush()
|
||||
|
||||
# Add regular variables to app2
|
||||
regular_vars = []
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
with session_factory.create_session() as session:
|
||||
app2 = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App 2",
|
||||
mode="workflow",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(var)
|
||||
regular_vars.append(var)
|
||||
db.session.commit()
|
||||
session.add(app2)
|
||||
session.flush()
|
||||
|
||||
for i in range(3):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app2.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
# Mock storage deletion
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Delete variables for app2 (no offload data)
|
||||
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
|
||||
assert deleted_count_app2 == 3
|
||||
|
||||
# Verify storage wasn't called for app2 (no offload files)
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
# Delete variables for original app (with offload data)
|
||||
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count_app1 == 3
|
||||
|
||||
# Now storage should be called for the offload files
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
finally:
|
||||
# Cleanup app2 and its variables
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
db.session.execute(cleanup_vars_query)
|
||||
|
||||
app2_obj = db.session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
db.session.delete(app2_obj)
|
||||
db.session.commit()
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_vars_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.app_id == app2.id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(cleanup_vars_query)
|
||||
app2_obj = session.get(App, app2.id)
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
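Pattern note: the draft-variable tests above exercise the commit's core change, replacing direct use of the global db.session with a short-lived SQLAlchemy Session obtained from session_factory.create_session() and committed explicitly. A minimal sketch of that pattern follows; the helper name purge_draft_variables is illustrative, and the import paths are assumptions based on the modules touched elsewhere in this commit.

from sqlalchemy import delete

from core.db.session_factory import session_factory
from models.workflow import WorkflowDraftVariable


def purge_draft_variables(variable_ids: list[str]) -> None:
    # Open a dedicated session rather than touching the global db.session.
    with session_factory.create_session() as session:
        stmt = (
            delete(WorkflowDraftVariable)
            .where(WorkflowDraftVariable.id.in_(variable_ids))
            .execution_options(synchronize_session=False)
        )
        session.execute(stmt)
        # Commit explicitly, mirroring the tests above; this sketch assumes
        # create_session() does not commit on exit.
        session.commit()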
|
||||
@@ -39,23 +39,22 @@ class TestCleanDatasetTask:
|
|||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Clear all test data
|
||||
db.session.query(DatasetMetadataBinding).delete()
|
||||
db.session.query(DatasetMetadata).delete()
|
||||
db.session.query(AppDatasetJoin).delete()
|
||||
db.session.query(DatasetQuery).delete()
|
||||
db.session.query(DatasetProcessRule).delete()
|
||||
db.session.query(DocumentSegment).delete()
|
||||
db.session.query(Document).delete()
|
||||
db.session.query(Dataset).delete()
|
||||
db.session.query(UploadFile).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
# Clear all test data using the provided session fixture
|
||||
db_session_with_containers.query(DatasetMetadataBinding).delete()
|
||||
db_session_with_containers.query(DatasetMetadata).delete()
|
||||
db_session_with_containers.query(AppDatasetJoin).delete()
|
||||
db_session_with_containers.query(DatasetQuery).delete()
|
||||
db_session_with_containers.query(DatasetProcessRule).delete()
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(UploadFile).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
redis_client.flushdb()
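These task-level integration test hunks stop importing the global db.session for setup and assertions and instead use the db_session_with_containers fixture that each test already receives. The fixture's definition is not part of this diff; a minimal illustrative version, assuming a conftest that exposes the test containers' engine as a db_engine fixture, might look like this.

from collections.abc import Iterator

import pytest
from sqlalchemy.orm import Session, sessionmaker


@pytest.fixture
def db_session_with_containers(db_engine) -> Iterator[Session]:
    # Illustrative only: bind a sessionmaker to the containerized test engine
    # (the db_engine fixture name is an assumption, not from this diff)
    # and give each test its own session, closing it when the test finishes.
    maker = sessionmaker(bind=db_engine, expire_on_commit=False)
    session = maker()
    try:
        yield session
    finally:
        session.close()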
|
||||
|
|
@@ -103,10 +102,8 @@ class TestCleanDatasetTask:
|
|||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
|
|
@@ -115,8 +112,8 @@ class TestCleanDatasetTask:
|
|||
status="active",
|
||||
)
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account relationship
|
||||
tenant_account_join = TenantAccountJoin(
|
||||
|
|
@@ -125,8 +122,8 @@ class TestCleanDatasetTask:
|
|||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
|
||||
db.session.add(tenant_account_join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant_account_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return account, tenant
|
||||
|
||||
|
|
@@ -155,10 +152,8 @@ class TestCleanDatasetTask:
|
|||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@@ -194,10 +189,8 @@ class TestCleanDatasetTask:
|
|||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
|
|
@@ -232,10 +225,8 @@ class TestCleanDatasetTask:
|
|||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segment
|
||||
|
||||
|
|
@@ -267,10 +258,8 @@ class TestCleanDatasetTask:
|
|||
used=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
|
|
@@ -302,31 +291,29 @@ class TestCleanDatasetTask:
|
|||
)
|
||||
|
||||
# Verify results
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check that dataset-related data was cleaned up
|
||||
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(documents) == 0
|
||||
|
||||
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Check that metadata and bindings were cleaned up
|
||||
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(metadata) == 0
|
||||
|
||||
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(bindings) == 0
|
||||
|
||||
# Check that process rules and queries were cleaned up
|
||||
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
|
||||
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(process_rules) == 0
|
||||
|
||||
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
|
||||
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(queries) == 0
|
||||
|
||||
# Check that app dataset joins were cleaned up
|
||||
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
|
||||
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(app_joins) == 0
|
||||
|
||||
# Verify index processor was called
|
||||
|
|
@@ -378,9 +365,7 @@ class TestCleanDatasetTask:
|
|||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset metadata and bindings
|
||||
metadata = DatasetMetadata(
|
||||
|
|
@@ -403,11 +388,9 @@ class TestCleanDatasetTask:
|
|||
binding.id = str(uuid.uuid4())
|
||||
binding.created_at = datetime.now()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(metadata)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(metadata)
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task
|
||||
clean_dataset_task(
|
||||
|
|
@@ -421,22 +404,24 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that metadata and bindings were cleaned up
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_bindings = (
|
||||
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_bindings) == 0
|
||||
|
||||
# Verify index processor was called
|
||||
|
|
@@ -489,12 +474,13 @@ class TestCleanDatasetTask:
|
|||
mock_index_processor.clean.assert_called_once()
|
||||
|
||||
# Check that all data was cleaned up
|
||||
from extensions.ext_database import db
|
||||
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = (
|
||||
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Recreate data for next test case
|
||||
|
|
@@ -540,14 +526,13 @@ class TestCleanDatasetTask:
|
|||
)
|
||||
|
||||
# Verify results - even with vector cleanup failure, documents and segments should be deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check that documents were still deleted despite vector cleanup failure
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that segments were still deleted despite vector cleanup failure
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Verify that index processor was called and failed
|
||||
|
|
@@ -608,10 +593,8 @@ class TestCleanDatasetTask:
|
|||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the get_image_upload_file_ids function to return our image file IDs
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
|
||||
|
|
@@ -629,16 +612,18 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all image files were deleted from database
|
||||
image_file_ids = [f.id for f in image_files]
|
||||
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
|
||||
remaining_image_files = (
|
||||
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
|
||||
)
|
||||
assert len(remaining_image_files) == 0
|
||||
|
||||
# Verify that storage.delete was called for each image file
|
||||
|
|
@@ -745,22 +730,24 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that all metadata and bindings were deleted
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_bindings = (
|
||||
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
|
||||
)
|
||||
assert len(remaining_bindings) == 0
|
||||
|
||||
# Verify performance expectations
|
||||
|
|
@@ -808,9 +795,7 @@ class TestCleanDatasetTask:
|
|||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock storage to raise exceptions
|
||||
mock_storage = mock_external_service_dependencies["storage"]
|
||||
|
|
@@ -827,18 +812,13 @@ class TestCleanDatasetTask:
|
|||
)
|
||||
|
||||
# Verify results
|
||||
# Check that documents were still deleted despite storage failure
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that segments were still deleted despite storage failure
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
# Note: When storage operations fail, database deletions may be rolled back by implementation.
|
||||
# This test focuses on ensuring the task handles the exception and continues execution/logging.
|
||||
|
||||
# Check that upload file was still deleted from database despite storage failure
|
||||
# Note: When storage operations fail, the upload file may not be deleted
|
||||
# This demonstrates that the cleanup process continues even with storage errors
|
||||
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
|
||||
# The upload file should still be deleted from the database even if storage cleanup fails
|
||||
# However, this depends on the specific implementation of clean_dataset_task
|
||||
if len(remaining_files) > 0:
|
||||
|
|
@@ -890,10 +870,8 @@ class TestCleanDatasetTask:
|
|||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create document with special characters in name
|
||||
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
|
||||
|
|
@@ -912,8 +890,8 @@ class TestCleanDatasetTask:
|
|||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create segment with special characters and very long content
|
||||
long_content = "Very long content " * 100 # Long content within reasonable limits
|
||||
|
|
@@ -934,8 +912,8 @@ class TestCleanDatasetTask:
|
|||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create upload file with special characters in name
|
||||
special_filename = f"test_file_{special_content}.txt"
|
||||
|
|
@@ -952,14 +930,14 @@ class TestCleanDatasetTask:
|
|||
created_at=datetime.now(),
|
||||
used=False,
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update document with file reference
|
||||
import json
|
||||
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Save upload file ID for verification
|
||||
upload_file_id = upload_file.id
|
||||
|
|
@@ -975,8 +953,8 @@ class TestCleanDatasetTask:
|
|||
special_metadata.id = str(uuid.uuid4())
|
||||
special_metadata.created_at = datetime.now()
|
||||
|
||||
db.session.add(special_metadata)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(special_metadata)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task
|
||||
clean_dataset_task(
|
||||
|
|
@@ -990,19 +968,19 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Verify results
|
||||
# Check that all documents were deleted
|
||||
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_documents) == 0
|
||||
|
||||
# Check that all segments were deleted
|
||||
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_segments) == 0
|
||||
|
||||
# Check that all upload files were deleted
|
||||
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
|
||||
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
|
||||
assert len(remaining_files) == 0
|
||||
|
||||
# Check that all metadata was deleted
|
||||
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(remaining_metadata) == 0
|
||||
|
||||
# Verify that storage.delete was called
|
||||
@@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
|
|||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
"""Clean up database and Redis before each test to ensure isolation."""
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Clear all test data
|
||||
db.session.query(DocumentSegment).delete()
|
||||
db.session.query(Document).delete()
|
||||
db.session.query(Dataset).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
# Clear all test data using fixture session
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
redis_client.flushdb()
|
||||
|
|
@@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
|
|||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
|
|
@@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
|
|||
status="normal",
|
||||
plan="basic",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
|
|
@@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
|
|||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
|
@@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
|
|||
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
|
||||
)
|
||||
|
||||
# Mock global database session to simulate transaction issues
|
||||
from extensions.ext_database import db
|
||||
|
||||
original_commit = db.session.commit
|
||||
commit_called = False
|
||||
|
||||
def mock_commit():
|
||||
nonlocal commit_called
|
||||
if not commit_called:
|
||||
commit_called = True
|
||||
raise Exception("Database commit failed")
|
||||
return original_commit()
|
||||
|
||||
db.session.commit = mock_commit
|
||||
# Simulate an error during indexing to trigger rollback path
|
||||
mock_processor = mock_external_service_dependencies["index_processor"]
|
||||
mock_processor.load.side_effect = Exception("Simulated indexing error")
|
||||
|
||||
# Act: Execute the task
|
||||
create_segment_to_index_task(segment.id)
|
||||
|
|
@@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
|
|||
assert segment.disabled_at is not None
|
||||
assert segment.error is not None
|
||||
|
||||
# Restore original commit method
|
||||
db.session.commit = original_commit
|
||||
|
||||
def test_create_segment_to_index_metadata_validation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
|
|
|||
|
|
@@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
|
|||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant = tenant
|
||||
|
|
@@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
|
|||
built_in_field_enabled=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
|
|||
document.archived = False
|
||||
document.doc_form = "text_model" # Use text_model form for testing
|
||||
document.doc_language = "en"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
|
|
@@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
|
|||
|
||||
segments.append(segment)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for segment in segments:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segments
|
||||
|
||||
|
|
@@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
|
|||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
mock_redis.delete.return_value = True
|
||||
|
||||
# Mock db.session.close to verify it's called
|
||||
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
|
||||
# Act
|
||||
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
|
||||
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Verify session was closed
|
||||
mock_close.assert_called()
|
||||
# Assert
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Session lifecycle is managed by context manager; no explicit close assertion
|
||||
|
||||
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
|
||||
"""
|
||||
@@ -6,7 +6,6 @@ from faker import Faker
|
|||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import (
|
||||
|
|
@@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
|
|||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
|
|
@@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
|
|||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
|
|
@@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
|
|||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
|
|
@@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
|
|||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
|
|
@@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
|
|||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
|
|
@@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
|
|||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
|
|
@@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
|
|||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
|
|
@@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
|
|||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Configure billing features
|
||||
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||
|
|
@@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
|
|||
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
|
|
@@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
|
|
@@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
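The expire_all() calls added in these hunks matter because _document_indexing commits its changes through its own session; without expiring the fixture session's identity map, re-queried Document rows could be served from stale cached state. A brief sketch of the idea, with an illustrative helper name and arguments:

from sqlalchemy.orm import Session

from models.dataset import Document


def assert_parsing(session: Session, doc_id: str) -> None:
    # Discard cached instance state so the next query reloads rows
    # committed by the task's own session.
    session.expire_all()
    doc = session.query(Document).where(Document.id == doc_id).first()
    assert doc is not None
    assert doc.indexing_status == "parsing"
    assert doc.processing_started_at is not None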
|
||||
|
||||
|
|
@@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task with mixed document IDs
|
||||
_document_indexing(dataset.id, all_document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify only existing documents were processed
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
|
@@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify only existing documents were updated
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in existing_document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
|
|
@@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing close the session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
|
|||
indexing_status="completed", # Already completed
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(doc1)
|
||||
db_session_with_containers.add(doc1)
|
||||
extra_documents.append(doc1)
|
||||
|
||||
# Document with disabled status
|
||||
|
|
@@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
|
|||
indexing_status="waiting",
|
||||
enabled=False, # Disabled
|
||||
)
|
||||
db.session.add(doc2)
|
||||
db_session_with_containers.add(doc2)
|
||||
extra_documents.append(doc2)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
all_documents = base_documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
|
@@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task with mixed document states
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
|
@@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify all documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
|
|||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
extra_documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
all_documents = documents + extra_documents
|
||||
document_ids = [doc.id for doc in all_documents]
|
||||
|
||||
# Act: Execute the task with too many documents for sandbox plan
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error handling
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert updated_document.error is not None
|
||||
assert "batch upload" in updated_document.error
|
||||
|
|
@@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task with billing disabled
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify successful processing
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
|
@@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the task
|
||||
_document_indexing(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify exception was handled gracefully
|
||||
# The task should complete without raising exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
|
|
@@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify core processing occurred (same as _document_indexing)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
|
@@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were updated (same as _document_indexing)
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the wrapper function
|
||||
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify error was handled gracefully
|
||||
# The function should not raise exceptions
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
|
|
@@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
|
|||
# Verify documents were still updated to parsing status before the exception
|
||||
# Re-query documents from database since _document_indexing uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
@@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
|
|||
# Act: Execute the wrapper function for tenant1 only
|
||||
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify core processing occurred for tenant1
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
@@ -4,7 +4,6 @@ import pytest
|
|||
from faker import Faker
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.duplicate_document_indexing_task import (
|
||||
|
|
@@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
|
|
@@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
|
|
@@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
|
|
@@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
|
|
@@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
indexing_at=fake.date_time_this_year(),
|
||||
created_by=dataset.created_by, # Add required field
|
||||
)
|
||||
db.session.add(segment)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh to ensure all relationships are loaded
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
|
||||
return dataset, documents, segments
|
||||
|
||||
|
|
@@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
|
|
@@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
|
|
@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create documents
|
||||
documents = []
|
||||
|
|
@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
documents.append(document)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Configure billing features
|
||||
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||
|
|
@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||
|
||||
# Refresh dataset to ensure it's properly loaded
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, documents
|
||||
|
||||
|
|
@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
# Act: Execute the task
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
|
||||
# Ensure we see committed changes from a different session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify indexing runner was called correctly
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
|
|
@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
|
|||
# Verify documents were updated to parsing status
|
||||
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
|
||||
for doc_id in document_ids:
|
||||
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
|
|
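# Note: the assertions above call expire_all() on db_session_with_containers before
# re-querying because the task commits through its own session. A minimal standalone
# sketch of why that expiry is needed, using an in-memory SQLite database and a toy
# model (illustrative only, not Dify code):
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str]


engine = create_engine("sqlite://")  # in-memory; both sessions share one connection
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

test_session = SessionLocal()
test_session.add(Doc(id=1, status="waiting"))
test_session.commit()

# A "task" session commits an update behind the test session's back.
with SessionLocal() as task_session:
    task_session.get(Doc, 1).status = "parsing"
    task_session.commit()

# Without expire_all(), the test session may keep serving its cached identity-map
# copy ("waiting"); after expiring, the re-query reloads the committed state.
test_session.expire_all()
assert test_session.scalars(select(Doc).where(Doc.id == 1)).one().status == "parsing"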
@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
segment_ids = [seg.id for seg in segments]

# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)

# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()

# Assert: Verify segment cleanup
db_session_with_containers.expire_all()

# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)

# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
# Re-query segments from database using captured IDs to avoid stale ORM instances
for seg_id in segment_ids:
deleted_segment = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
)
assert deleted_segment is None

# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None

@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)

# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()

# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None

@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)

# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()

# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()

@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task close the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None

@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)

db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]

# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)

# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()

# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()

@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)

# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()

# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()

@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()

# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"

@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")

@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()

# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()

# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"

@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")

@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()

# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()

# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"

@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -49,10 +49,14 @@ def pipeline_id():

@pytest.fixture
def mock_db_session():
"""Mock database session with query capabilities."""
with patch("tasks.clean_dataset_task.db") as mock_db:
"""Mock database session via session_factory.create_session()."""
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
mock_session = MagicMock()
mock_db.session = mock_session
# context manager for create_session()
cm = MagicMock()
cm.__enter__.return_value = mock_session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm

# Setup query chain
mock_query = MagicMock()

@ -66,7 +70,10 @@ def mock_db_session():
# Setup execute for JOIN queries
mock_session.execute.return_value.all.return_value = []

yield mock_db
# Yield an object with a `.session` attribute to keep tests unchanged
wrapper = MagicMock()
wrapper.session = mock_session
yield wrapper


@pytest.fixture
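# The fixture above swaps a patched `db` object for a patched `session_factory`
# whose create_session() returns a context manager yielding a MagicMock session.
# A standalone sketch of that wiring (names are illustrative, not the Dify API):
from unittest.mock import MagicMock

mock_session = MagicMock()
session_factory = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = mock_session
cm.__exit__.return_value = None
session_factory.create_session.return_value = cm

# Code under test would do exactly this:
with session_factory.create_session() as session:
    session.add("row")
    session.commit()

assert session is mock_session
mock_session.commit.assert_called_once()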
@ -227,7 +234,9 @@ class TestBasicCleanup:

# Assert
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()

def test_clean_dataset_task_deletes_related_records(
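# The assertion style above collects every statement passed to the mocked
# session.execute(), normalizes whitespace, and searches for the expected DELETE.
# A runnable sketch of the same technique (SQL text and table name illustrative):
from unittest.mock import MagicMock

from sqlalchemy import text

session = MagicMock()

# Code under test issues a batch DELETE instead of per-row session.delete():
session.execute(
    text("DELETE FROM document_segments\n WHERE document_id = :doc_id"),
    {"doc_id": "d1"},
)

# Test side: flatten each executed statement to one whitespace-normalized string.
executed = [" ".join(str(call.args[0]).split()) for call in session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in executed)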
@ -413,7 +422,9 @@ class TestErrorHandling:

# Assert - documents and segments should still be deleted
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()

def test_clean_dataset_task_storage_delete_failure_continues(

@ -461,7 +472,7 @@ class TestErrorHandling:
[mock_segment],  # segments
]
mock_get_image_upload_file_ids.return_value = [image_file_id]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
mock_storage.delete.side_effect = Exception("Storage service unavailable")

# Act

@ -476,8 +487,9 @@ class TestErrorHandling:

# Assert - storage delete was attempted for image file
mock_storage.delete.assert_called_with(mock_upload_file.key)
# Image file should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)

def test_clean_dataset_task_database_error_rollback(
self,

@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:

# Assert
mock_storage.delete.assert_called_with(mock_attachment_file.key)
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Attachment file and binding are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)

def test_clean_dataset_task_attachment_storage_failure(
self,

@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:

# Assert - storage delete was attempted
mock_storage.delete.assert_called_once()
# Records should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Records are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)


# ============================================================================

@ -784,7 +799,7 @@ class TestUploadFileCleanup:
[mock_document],  # documents
[],  # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]

# Act
clean_dataset_task(

@ -798,7 +813,9 @@ class TestUploadFileCleanup:

# Assert
mock_storage.delete.assert_called_with(mock_upload_file.key)
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)

def test_clean_dataset_task_handles_missing_upload_file(
self,

@ -832,7 +849,7 @@ class TestUploadFileCleanup:
[mock_document],  # documents
[],  # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []

# Act - should not raise exception
clean_dataset_task(

@ -949,11 +966,11 @@ class TestImageFileCleanup:
[mock_segment],  # segments
]

# Setup a mock query chain that returns files in sequence
# Setup a mock query chain that returns files in batch (align with .in_().all())
mock_query = MagicMock()
mock_where = MagicMock()
mock_query.where.return_value = mock_where
mock_where.first.side_effect = mock_image_files
mock_where.all.return_value = mock_image_files
mock_db_session.session.query.return_value = mock_query

# Act

@ -966,10 +983,10 @@ class TestImageFileCleanup:
doc_form="paragraph_index",
)

# Assert
assert mock_storage.delete.call_count == 2
mock_storage.delete.assert_any_call("images/image-1.jpg")
mock_storage.delete.assert_any_call("images/image-2.jpg")
# Assert - each expected image key was deleted at least once
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
assert "images/image-1.jpg" in calls
assert "images/image-2.jpg" in calls

def test_clean_dataset_task_handles_missing_image_file(
self,

@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
]

# Image file not found
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []

# Act - should not raise exception
clean_dataset_task(

@ -1086,14 +1103,15 @@ class TestEdgeCases:
doc_form="paragraph_index",
)

# Assert - all documents and segments should be deleted
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
delete_calls = mock_db_session.session.delete.call_args_list
deleted_items = [call[0][0] for call in delete_calls]

for doc in mock_documents:
assert doc in deleted_items
for seg in mock_segments:
assert seg in deleted_items
# Verify a batch DELETE on document_segments occurred
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)

def test_clean_dataset_task_document_with_empty_data_source_info(
self,
@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):

@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown

def _exit_side_effect(*args, **kwargs):
session.close()

cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm

query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session


@pytest.fixture
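# The fixture above routes cm.__exit__ to session.close() so existing tests that
# assert on close() still see it fire when the task leaves its `with` block.
# A standalone sketch of that wiring (names are illustrative):
from unittest.mock import MagicMock

session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session


def _exit_side_effect(*args, **kwargs):
    # Forward context-manager teardown to the session's close() mock.
    session.close()


cm.__exit__.side_effect = _exit_side_effect

factory = MagicMock()
factory.create_session.return_value = cm

with factory.create_session() as s:
    s.query("Document")

# Leaving the with-block triggered __exit__, which in turn called close().
assert s is session
session.close.assert_called_once()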
@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task

@pytest.fixture
def mock_db_session():
"""Mock the db.session used in delete_account_task."""
with patch("tasks.delete_account_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock session via session_factory.create_session()."""
with patch("tasks.delete_account_task.session_factory") as mock_sf:
session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm

query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session


@pytest.fixture
@ -109,13 +109,25 @@ def mock_document_segments(document_id):

@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session

def _exit_side_effect(*args, **kwargs):
session.close()

cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm

query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session


@pytest.fixture

@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
# Session should still be closed via context manager teardown
assert mock_db_session.close.called

def test_successful_sync_when_page_updated(
self,

@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()

# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)

# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
@ -94,13 +94,25 @@ def mock_document_segments(document_ids):

@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session

def _exit_side_effect(*args, **kwargs):
session.close()

cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm

query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session


@pytest.fixture

@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test successful duplicate document indexing flow."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# Dataset via query.first()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# scalars() call sequence:
# 1) documents list
# 2..N) segments per document

def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
# First call returns documents; subsequent calls return segments
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m

mock_db_session.scalars.side_effect = _scalars_side_effect

# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
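# The `_scalars_side_effect` helpers in these tests sequence scalars() results by
# keeping a counter on the function object. An equivalent, shorter option is a plain
# side_effect list, one entry per expected call (sketch only; names illustrative,
# and the list length must match the number of scalars() calls the task makes):
from unittest.mock import MagicMock

mock_documents = ["doc-1", "doc-2"]
mock_document_segments = ["seg-1", "seg-2", "seg-3"]

first = MagicMock()
first.all.return_value = mock_documents          # 1st scalars() call -> documents
rest = MagicMock()
rest.all.return_value = mock_document_segments   # later scalars() calls -> segments

session = MagicMock()
# documents once, then segments for each document
session.scalars.side_effect = [first] + [rest] * len(mock_documents)

assert session.scalars("query-1").all() == mock_documents
assert session.scalars("query-2").all() == mock_document_segments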
@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when billing limit is exceeded."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []  # No segments to clean
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# First scalars() -> documents; subsequent -> empty segments

def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m

mock_db_session.scalars.side_effect = _scalars_side_effect
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.TEAM

@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m

mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = Exception("Indexing error")

# Act

@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when document is paused."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m

mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")

# Act

@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test that duplicate document indexing cleans old segments."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset

def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m

mock_db_session.scalars.side_effect = _scalars_side_effect
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value

# Act

@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
# Verify clean was called for each document
assert mock_processor.clean.call_count == len(mock_documents)

# Verify segments were deleted
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)


# ============================================================================
@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (

class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100

# Mock database connection and engine
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager

# Mock two batches of results, then empty
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]

@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
select_result3.__iter__.return_value = iter([])

# Configure side effects in the correct order
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
select_result1,  # First SELECT
delete_result1,  # First DELETE
select_result2,  # Second SELECT

@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
assert result == 150

# Verify database calls
assert mock_conn.execute.call_count == 5  # 3 selects + 2 deletes
assert mock_session.execute.call_count == 5  # 3 selects + 2 deletes

# Verify offload cleanup was called for both batches with file_ids
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
mock_offload_cleanup.assert_has_calls(expected_offload_calls)

# Simplified verification - check that the right number of calls were made
# and that the SQL queries contain the expected patterns
actual_calls = mock_conn.execute.call_args_list
actual_calls = mock_session.execute.call_args_list
for i, actual_call in enumerate(actual_calls):
sql_text = str(actual_call[0][0])
normalized = " ".join(sql_text.split())
if i % 2 == 0:  # SELECT calls (even indices: 0, 2, 4)
# Verify it's a SELECT query that now includes file_id
sql_text = str(actual_call[0][0])
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
assert "WHERE app_id = :app_id" in sql_text
assert "LIMIT :batch_size" in sql_text
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
assert "WHERE app_id = :app_id" in normalized
assert "LIMIT :batch_size" in normalized
else:  # DELETE calls (odd indices: 1, 3)
# Verify it's a DELETE query
sql_text = str(actual_call[0][0])
assert "DELETE FROM workflow_draft_variables" in sql_text
assert "WHERE id IN :ids" in sql_text
assert "DELETE FROM workflow_draft_variables" in normalized
assert "WHERE id IN :ids" in normalized

@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000

# Mock database connection
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager

# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.return_value = empty_result
mock_session.execute.return_value = empty_result

result = delete_draft_variables_batch(app_id, batch_size)

assert result == 0
assert mock_conn.execute.call_count == 1  # Only one select query
assert mock_session.execute.call_count == 1  # Only one select query
mock_offload_cleanup.assert_not_called()  # No files to clean up

def test_delete_draft_variables_batch_invalid_batch_size(self):

@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
delete_draft_variables_batch(app_id, 0)

@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50

# Mock database
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager

# Mock one batch then empty
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]

@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])

mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
# Select query result
select_result,
# Delete query result

@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:

# Verify offload cleanup was called with file_ids
if batch_file_ids:
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)

# Verify logging calls
assert mock_logging.info.call_count == 2

@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
actual_calls = mock_conn.execute.call_args_list

# First call should be the SELECT query
select_call_sql = str(actual_calls[0][0][0])
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
assert "WHERE wdvf.id IN :file_ids" in select_call_sql

# Second call should be DELETE upload_files
delete_upload_call_sql = str(actual_calls[1][0][0])
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
assert "DELETE FROM upload_files" in delete_upload_call_sql
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql

# Third call should be DELETE workflow_draft_variable_files
delete_variable_files_call_sql = str(actual_calls[2][0][0])
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql
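# The batch tests above drive a select-then-delete loop over workflow_draft_variables.
# A minimal sketch of that loop under the assumptions the tests encode (an `id`,
# `file_id`, `app_id` table and a separate offload cleanup step); illustrative only,
# not the Dify implementation:
from sqlalchemy import bindparam, text
from sqlalchemy.orm import Session


def _cleanup_offloaded_files(session: Session, file_ids: list[str]) -> None:
    # Stand-in for the offload-data cleanup the tests patch out.
    pass


def delete_draft_variables_in_batches(session: Session, app_id: str, batch_size: int) -> int:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total = 0
    while True:
        # SELECT one batch of ids (and any offloaded file ids) for the app.
        rows = session.execute(
            text(
                "SELECT id, file_id FROM workflow_draft_variables "
                "WHERE app_id = :app_id LIMIT :batch_size"
            ),
            {"app_id": app_id, "batch_size": batch_size},
        ).all()
        if not rows:
            return total
        ids = [row[0] for row in rows]
        file_ids = [row[1] for row in rows if row[1] is not None]
        if file_ids:
            _cleanup_offloaded_files(session, file_ids)
        # DELETE the batch by id; expanding bindparam renders the IN clause.
        session.execute(
            text("DELETE FROM workflow_draft_variables WHERE id IN :ids").bindparams(
                bindparam("ids", expanding=True)
            ),
            {"ids": ids},
        )
        total += len(ids)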