refactor(tasks): migrate document_indexing_task and remove_app_and_related_data_task to SQLAlchemy 2.0 select() API (#34968)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
wdeveloper16 2026-04-12 03:49:56 +02:00 committed by GitHub
parent 0841b4c663
commit f67297688f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 233 additions and 232 deletions

View File

@ -5,6 +5,7 @@ from typing import Any, Protocol
import click
from celery import current_app, shared_task
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
return
@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
)
except Exception as e:
for document_id in document_ids:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = IndexingStatus.ERROR
@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
documents: list[Document] = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
for document in documents:
@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.warning("Dataset %s not found after indexing", dataset_id)
return
@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
for document in documents:

View File

@ -6,7 +6,7 @@ from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
session.execute(
delete(AppModelConfig)
.where(AppModelConfig.id == model_config_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(session, site_id: str):
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(session, mcp_server_id: str):
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
session.execute(
delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1))
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
session.execute(
delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(session, installed_app_id: str):
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
session.execute(
delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(session, recommended_app_id: str):
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
session.execute(
delete(RecommendedApp)
.where(RecommendedApp.id == recommended_app_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(session, annotation_hit_history_id: str):
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.id == annotation_hit_history_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)
def del_annotation_setting(session, annotation_setting_id: str):
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
session.execute(
delete(AppAnnotationSetting)
.where(AppAnnotationSetting.id == annotation_setting_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(session, dataset_join_id: str):
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
session.execute(
delete(AppDatasetJoin)
.where(AppDatasetJoin.id == dataset_join_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(session, workflow_id: str):
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(session, workflow_app_log_id: str):
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowAppLog)
.where(WorkflowAppLog.id == workflow_app_log_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowArchiveLog)
.where(WorkflowArchiveLog.id == workflow_archive_log_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(session, conversation_id: str):
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
session.execute(
delete(PinnedConversation)
.where(PinnedConversation.conversation_id == conversation_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False)
)
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str):
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(session, message_id: str):
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageFeedback)
.where(MessageFeedback.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
session.execute(
delete(MessageAnnotation)
.where(MessageAnnotation.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
session.query(Message).where(Message.id == message_id).delete()
session.execute(
delete(MessageChain)
.where(MessageChain.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageAgentThought)
.where(MessageAgentThought.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(
delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False)
)
session.execute(
delete(SavedMessage)
.where(SavedMessage.message_id == message_id)
.execution_options(synchronize_session=False)
)
session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(session, tool_provider_id: str):
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowToolProvider)
.where(WorkflowToolProvider.id == tool_provider_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(session, tag_binding_id: str):
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
session.execute(
delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(session, end_user_id: str):
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False))
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(session, trace_app_config_id: str):
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
session.execute(
delete(TraceAppConfig)
.where(TraceAppConfig.id == trace_app_config_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(session, trigger_id: str):
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
session.execute(
delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False)
)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(session, trigger_id: str):
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowPluginTrigger)
.where(WorkflowPluginTrigger.id == trigger_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(session, trigger_id: str):
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
session.execute(
delete(WorkflowWebhookTrigger)
.where(WorkflowWebhookTrigger.id == trigger_id)
.execution_options(synchronize_session=False)
)
_delete_records(
@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(session, plan_id: str):
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowSchedulePlan)
.where(WorkflowSchedulePlan.id == plan_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(session, log_id: str):
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
session.execute(
delete(WorkflowTriggerLog)
.where(WorkflowTriggerLog.id == log_id)
.execution_options(synchronize_session=False)
)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",

View File

@ -82,8 +82,8 @@ def mock_db_session():
"""Mock session_factory.create_session() to return a session whose queries use shared test data.
Tests set session._shared_data = {"dataset": <Dataset>, "documents": [<Document>, ...]}
This fixture makes session.query(Dataset).first() return the shared dataset,
and session.query(Document).all()/first() return from the shared documents.
This fixture makes session.scalar(select(Dataset)...) return the shared dataset,
and session.scalars(select(Document)...).all() return the shared documents.
"""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
@ -92,93 +92,68 @@ def mock_db_session():
# Keep a pointer so repeated Document.first() calls iterate across provided docs
session._doc_first_idx = 0
def _query_side_effect(model):
q = MagicMock()
def _get_entity(stmt) -> type | None:
"""Extract the mapped entity class from a SQLAlchemy select statement."""
try:
descs = stmt.column_descriptions
if descs:
return descs[0].get("entity")
except (AttributeError, TypeError):
pass
return None
# Capture filters passed via where(...) so first()/all() can honor them.
q._filters = {}
def _extract_id_from_where(stmt) -> str | None:
"""Return the value bound to the 'id' column in the WHERE clause, if present."""
try:
where = stmt.whereclause
if where is None:
return None
# Both single-clause and AND-clause-list cases
clauses = list(getattr(where, "clauses", [where]))
for clause in clauses:
left = getattr(clause, "left", None)
right = getattr(clause, "right", None)
if left is not None and right is not None:
if getattr(left, "key", None) == "id":
return getattr(right, "value", None)
except Exception:
pass
return None
def _extract_filters(*conds, **kw):
# Support both SQLAlchemy expressions (BinaryExpression) and kwargs
# We only need the simple fields used by production code: id, dataset_id, and id.in_(...)
for cond in conds:
left = getattr(cond, "left", None)
right = getattr(cond, "right", None)
key = None
if left is not None:
key = getattr(left, "key", None) or getattr(left, "name", None)
if not key:
continue
# Right side might be a BindParameter with .value, or a raw value/sequence
val = getattr(right, "value", right)
q._filters[key] = val
# Also accept kwargs (e.g., where(id=...)) just in case
for k, v in kw.items():
q._filters[k] = v
def _where_side_effect(*conds, **kw):
_extract_filters(*conds, **kw)
return q
q.where.side_effect = _where_side_effect
# Dataset queries
if model.__name__ == "Dataset":
def _dataset_first():
ds = session._shared_data.get("dataset")
if not ds:
return None
if "id" in q._filters:
val = q._filters["id"]
if isinstance(val, (list, tuple, set)):
return ds if ds.id in val else None
return ds if ds.id == val else None
return ds
def _dataset_all():
ds = session._shared_data.get("dataset")
if not ds:
return []
first = _dataset_first()
return [first] if first else []
q.first.side_effect = _dataset_first
q.all.side_effect = _dataset_all
return q
# Document queries
if model.__name__ == "Document":
def _apply_doc_filters(docs):
result = list(docs)
for key in ("id", "dataset_id"):
if key in q._filters:
val = q._filters[key]
if isinstance(val, (list, tuple, set)):
result = [d for d in result if getattr(d, key, None) in val]
else:
result = [d for d in result if getattr(d, key, None) == val]
return result
def _docs_all():
def _scalar_side_effect(stmt):
entity = _get_entity(stmt)
if entity is not None:
if entity.__name__ == "Dataset":
return session._shared_data.get("dataset")
elif entity.__name__ == "Document":
docs = session._shared_data.get("documents", [])
return _apply_doc_filters(docs)
if not docs:
return None
# When the WHERE clause filters by id, return the matching document
queried_id = _extract_id_from_where(stmt)
if queried_id:
doc_map = {d.id: d for d in docs}
return doc_map.get(queried_id, docs[0])
return docs[0]
return None
def _docs_first():
docs = _docs_all()
return docs[0] if docs else None
def _scalars_side_effect(stmt):
entity = _get_entity(stmt)
result = MagicMock()
if entity is not None:
if entity.__name__ == "Document":
result.all.return_value = list(session._shared_data.get("documents", []))
elif entity.__name__ == "Dataset":
ds = session._shared_data.get("dataset")
result.all.return_value = [ds] if ds else []
else:
result.all.return_value = []
else:
result.all.return_value = []
return result
q.all.side_effect = _docs_all
q.first.side_effect = _docs_first
return q
# Default fallback
q.first.return_value = None
q.all.return_value = []
return q
session.query.side_effect = _query_side_effect
session.scalar.side_effect = _scalar_side_effect
session.scalars.side_effect = _scalars_side_effect
# Implement session.begin() context manager that commits on exit
session.commit = MagicMock()
@ -638,8 +613,6 @@ class TestProgressTracking:
wrapper = TaskWrapper(data=next_task_data)
mock_redis.rpop.return_value = wrapper.serialize()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -662,7 +635,6 @@ class TestProgressTracking:
"""
# Arrange
mock_redis.rpop.return_value = None # No more tasks
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -780,8 +752,7 @@ class TestErrorHandling:
If the dataset doesn't exist, the task should exit gracefully.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Arrange - dataset is not in _shared_data (None by default), so scalar() returns None
# Act
_document_indexing(dataset_id, document_ids)
@ -806,8 +777,6 @@ class TestErrorHandling:
# Set up rpop to return task once for concurrency check
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Make _document_indexing raise an error
with patch("tasks.document_indexing_task._document_indexing") as mock_indexing:
mock_indexing.side_effect = Exception("Processing failed")
@ -844,7 +813,7 @@ class TestErrorHandling:
# Mock rpop to return tasks one by one
mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -977,7 +946,7 @@ class TestAdvancedScenarios:
# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1070,7 +1039,7 @@ class TestAdvancedScenarios:
# Mock rpop to return tasks in FIFO order
mock_redis.rpop.side_effect = tasks + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1108,7 +1077,7 @@ class TestAdvancedScenarios:
"""
# Arrange
mock_redis.rpop.return_value = None # Empty queue
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
# Act
@ -1276,7 +1245,7 @@ class TestIntegration:
# First call returns task 2, second call returns None
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@ -1433,7 +1402,7 @@ class TestPerformanceScenarios:
# Mock rpop to return tasks up to concurrency limit
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_db_session._shared_data["dataset"] = mock_dataset
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset does not exist."""
# Arrange
session = MagicMock()
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = None
session.query.side_effect = lambda model: dataset_query
session = MagicMock()
session.scalar.return_value = None # dataset not found
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow:
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.first.return_value = document
session = MagicMock()
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
def _scalar_se(stmt):
entity = stmt.column_descriptions[0].get("entity")
if entity is Dataset:
return dataset
return document
session.scalar.side_effect = _scalar_se
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow:
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_document_query
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(
all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status])
)
create_session_mock = MagicMock(
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow:
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test early return when dataset is missing after indexing."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.side_effect = [dataset, None]
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = None # dataset not found on second query
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow:
_document_indexing("dataset-1", ["doc-1"])
# Assert
session3.query.assert_called()
session3.scalar.assert_called()
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test summary generation skipped when indexing_technique is not high_quality."""
@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="economy",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session3.query.side_effect = lambda model: dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test summary generation is skipped when indexing is paused."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow:
"""Test generic indexing runner exception is handled."""
# Arrange
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
document_query = MagicMock()
document_query.where.return_value = document_query
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: document_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",
@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow:
indexing_technique="high_quality",
summary_index_setting={"enable": True},
)
dataset_query = MagicMock()
dataset_query.where.return_value = dataset_query
dataset_query.first.return_value = dataset
phase1_query = MagicMock()
phase1_query.where.return_value = phase1_query
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
summary_query = MagicMock()
summary_query.where.return_value = summary_query
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
session1 = MagicMock()
session2 = MagicMock()
session2.begin.return_value = nullcontext()
session3 = MagicMock()
session1.query.side_effect = lambda model: dataset_query
session2.query.side_effect = lambda model: phase1_query
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
session1.scalar.return_value = dataset
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
session3.scalar.return_value = dataset
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")]))
monkeypatch.setattr(
"tasks.document_indexing_task.session_factory.create_session",

View File

@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch
import pytest
from libs.archive_storage import ArchiveStorageNotConfiguredError
from models.workflow import WorkflowArchiveLog
from tasks.remove_app_and_related_data_task import (
_delete_app_workflow_archive_logs,
_delete_archived_workflow_run_files,
@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs:
assert params == {"tenant_id": tenant_id, "app_id": app_id}
assert name == "workflow archive log"
mock_query = MagicMock()
mock_delete_query = MagicMock()
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
mock_session = MagicMock()
delete_func(mock_db.session, "log-1")
delete_func(mock_session, "log-1")
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()
mock_delete_query.delete.assert_called_once_with(synchronize_session=False)
mock_session.execute.assert_called_once()
class TestDeleteArchivedWorkflowRunFiles: