From f67297688f13674d894729da1ac66364f52c4909 Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Sun, 12 Apr 2026 03:49:56 +0200 Subject: [PATCH] refactor(tasks): migrate document_indexing_task and remove_app_and_related_data_task to SQLAlchemy 2.0 select() API (#34968) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/tasks/document_indexing_task.py | 24 +- api/tasks/remove_app_and_related_data_task.py | 154 +++++++--- .../tasks/test_dataset_indexing_task.py | 275 +++++++----------- .../test_remove_app_and_related_data_task.py | 12 +- 4 files changed, 233 insertions(+), 232 deletions(-) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 23a80fa106..31dad7937c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -5,6 +5,7 @@ from typing import Any, Protocol import click from celery import current_app, shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): Usage: _document_indexing(dataset_id, document_ids) """ - documents = [] start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) return @@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) except Exception as e: for document_id in document_ids: - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Phase 1: Update status to parsing (short transaction) with session_factory.create_session() as session, session.begin(): - documents = ( - session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + documents: list[Document] = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: @@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.warning("Dataset %s not found after indexing", dataset_id) return @@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.expire_all() # Check each document's indexing status and trigger summary generation if completed - documents = ( - session.query(Document) - .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) - .all() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b1840662ff..72d824b8c1 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,7 +6,7 @@ from typing import Any, cast import click import sqlalchemy as sa from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): def del_model_config(session, model_config_id: str): - session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + session.execute( + delete(AppModelConfig) + .where(AppModelConfig.id == model_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): def del_site(session, site_id: str): - session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False)) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): def del_mcp_server(session, mcp_server_id: str): - session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + session.execute( + delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(session, api_token_id: str): # Fetch token details for cache invalidation - token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first() + token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1)) if token_obj: # Invalidate cache before deletion ApiTokenCache.delete(token_obj.token, token_obj.type) - session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + session.execute( + delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): def del_installed_app(session, installed_app_id: str): - session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + session.execute( + delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(session, recommended_app_id: str): - session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) + session.execute( + delete(RecommendedApp) + .where(RecommendedApp.id == recommended_app_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(session, annotation_hit_history_id: str): - session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationHitHistory) + .where(AppAnnotationHitHistory.id == annotation_hit_history_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): ) def del_annotation_setting(session, annotation_setting_id: str): - session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationSetting) + .where(AppAnnotationSetting.id == annotation_setting_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): def del_dataset_join(session, dataset_join_id: str): - session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + session.execute( + delete(AppDatasetJoin) + .where(AppDatasetJoin.id == dataset_join_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): def del_workflow(session, workflow_id: str): - session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False)) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(session, workflow_app_log_id: str): - session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id == workflow_app_log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): def del_workflow_archive_log(session, workflow_archive_log_id: str): - session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowArchiveLog) + .where(WorkflowArchiveLog.id == workflow_archive_log_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute( + delete(PinnedConversation) + .where(PinnedConversation.conversation_id == conversation_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False) ) - session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str): def _delete_app_messages(tenant_id: str, app_id: str): def del_message(session, message_id: str): - session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageFeedback) + .where(MessageFeedback.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageAnnotation) + .where(MessageAnnotation.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - session.query(Message).where(Message.id == message_id).delete() + session.execute( + delete(MessageChain) + .where(MessageChain.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageAgentThought) + .where(MessageAgentThought.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False) + ) + session.execute( + delete(SavedMessage) + .where(SavedMessage.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False)) _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(session, tool_provider_id: str): - session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowToolProvider) + .where(WorkflowToolProvider.id == tool_provider_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): def del_tag_binding(session, tag_binding_id: str): - session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + session.execute( + delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): def del_end_user(session, end_user_id: str): - session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False)) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(session, trace_app_config_id: str): - session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) + session.execute( + delete(TraceAppConfig) + .where(TraceAppConfig.id == trace_app_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): def del_app_trigger(session, trigger_id: str): - session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + session.execute( + delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def del_plugin_trigger(session, trigger_id: str): - session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def del_webhook_trigger(session, trigger_id: str): - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowWebhookTrigger) + .where(WorkflowWebhookTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def del_schedule_plan(session, plan_id: str): - session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowSchedulePlan) + .where(WorkflowSchedulePlan.id == plan_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def del_trigger_log(session, log_id: str): - session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowTriggerLog) + .where(WorkflowTriggerLog.id == log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 34e474c921..5dad58b8f1 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -82,8 +82,8 @@ def mock_db_session(): """Mock session_factory.create_session() to return a session whose queries use shared test data. Tests set session._shared_data = {"dataset": , "documents": [, ...]} - This fixture makes session.query(Dataset).first() return the shared dataset, - and session.query(Document).all()/first() return from the shared documents. + This fixture makes session.scalar(select(Dataset)...) return the shared dataset, + and session.scalars(select(Document)...).all() return the shared documents. """ with patch("tasks.document_indexing_task.session_factory") as mock_sf: session = MagicMock() @@ -92,93 +92,68 @@ def mock_db_session(): # Keep a pointer so repeated Document.first() calls iterate across provided docs session._doc_first_idx = 0 - def _query_side_effect(model): - q = MagicMock() + def _get_entity(stmt) -> type | None: + """Extract the mapped entity class from a SQLAlchemy select statement.""" + try: + descs = stmt.column_descriptions + if descs: + return descs[0].get("entity") + except (AttributeError, TypeError): + pass + return None - # Capture filters passed via where(...) so first()/all() can honor them. - q._filters = {} + def _extract_id_from_where(stmt) -> str | None: + """Return the value bound to the 'id' column in the WHERE clause, if present.""" + try: + where = stmt.whereclause + if where is None: + return None + # Both single-clause and AND-clause-list cases + clauses = list(getattr(where, "clauses", [where])) + for clause in clauses: + left = getattr(clause, "left", None) + right = getattr(clause, "right", None) + if left is not None and right is not None: + if getattr(left, "key", None) == "id": + return getattr(right, "value", None) + except Exception: + pass + return None - def _extract_filters(*conds, **kw): - # Support both SQLAlchemy expressions (BinaryExpression) and kwargs - # We only need the simple fields used by production code: id, dataset_id, and id.in_(...) - for cond in conds: - left = getattr(cond, "left", None) - right = getattr(cond, "right", None) - key = None - if left is not None: - key = getattr(left, "key", None) or getattr(left, "name", None) - if not key: - continue - # Right side might be a BindParameter with .value, or a raw value/sequence - val = getattr(right, "value", right) - q._filters[key] = val - # Also accept kwargs (e.g., where(id=...)) just in case - for k, v in kw.items(): - q._filters[k] = v - - def _where_side_effect(*conds, **kw): - _extract_filters(*conds, **kw) - return q - - q.where.side_effect = _where_side_effect - - # Dataset queries - if model.__name__ == "Dataset": - - def _dataset_first(): - ds = session._shared_data.get("dataset") - if not ds: - return None - if "id" in q._filters: - val = q._filters["id"] - if isinstance(val, (list, tuple, set)): - return ds if ds.id in val else None - return ds if ds.id == val else None - return ds - - def _dataset_all(): - ds = session._shared_data.get("dataset") - if not ds: - return [] - first = _dataset_first() - return [first] if first else [] - - q.first.side_effect = _dataset_first - q.all.side_effect = _dataset_all - return q - - # Document queries - if model.__name__ == "Document": - - def _apply_doc_filters(docs): - result = list(docs) - for key in ("id", "dataset_id"): - if key in q._filters: - val = q._filters[key] - if isinstance(val, (list, tuple, set)): - result = [d for d in result if getattr(d, key, None) in val] - else: - result = [d for d in result if getattr(d, key, None) == val] - return result - - def _docs_all(): + def _scalar_side_effect(stmt): + entity = _get_entity(stmt) + if entity is not None: + if entity.__name__ == "Dataset": + return session._shared_data.get("dataset") + elif entity.__name__ == "Document": docs = session._shared_data.get("documents", []) - return _apply_doc_filters(docs) + if not docs: + return None + # When the WHERE clause filters by id, return the matching document + queried_id = _extract_id_from_where(stmt) + if queried_id: + doc_map = {d.id: d for d in docs} + return doc_map.get(queried_id, docs[0]) + return docs[0] + return None - def _docs_first(): - docs = _docs_all() - return docs[0] if docs else None + def _scalars_side_effect(stmt): + entity = _get_entity(stmt) + result = MagicMock() + if entity is not None: + if entity.__name__ == "Document": + result.all.return_value = list(session._shared_data.get("documents", [])) + elif entity.__name__ == "Dataset": + ds = session._shared_data.get("dataset") + result.all.return_value = [ds] if ds else [] + else: + result.all.return_value = [] + else: + result.all.return_value = [] + return result - q.all.side_effect = _docs_all - q.first.side_effect = _docs_first - return q - - # Default fallback - q.first.return_value = None - q.all.return_value = [] - return q - - session.query.side_effect = _query_side_effect + session.scalar.side_effect = _scalar_side_effect + session.scalars.side_effect = _scalars_side_effect # Implement session.begin() context manager that commits on exit session.commit = MagicMock() @@ -638,8 +613,6 @@ class TestProgressTracking: wrapper = TaskWrapper(data=next_task_data) mock_redis.rpop.return_value = wrapper.serialize() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -662,7 +635,6 @@ class TestProgressTracking: """ # Arrange mock_redis.rpop.return_value = None # No more tasks - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -780,8 +752,7 @@ class TestErrorHandling: If the dataset doesn't exist, the task should exit gracefully. """ - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = None + # Arrange - dataset is not in _shared_data (None by default), so scalar() returns None # Act _document_indexing(dataset_id, document_ids) @@ -806,8 +777,6 @@ class TestErrorHandling: # Set up rpop to return task once for concurrency check mock_redis.rpop.side_effect = [wrapper.serialize(), None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # Make _document_indexing raise an error with patch("tasks.document_indexing_task._document_indexing") as mock_indexing: mock_indexing.side_effect = Exception("Processing failed") @@ -844,7 +813,7 @@ class TestErrorHandling: # Mock rpop to return tasks one by one mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -977,7 +946,7 @@ class TestAdvancedScenarios: # Mock rpop to return tasks up to concurrency limit mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1070,7 +1039,7 @@ class TestAdvancedScenarios: # Mock rpop to return tasks in FIFO order mock_redis.rpop.side_effect = tasks + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1108,7 +1077,7 @@ class TestAdvancedScenarios: """ # Arrange mock_redis.rpop.return_value = None # Empty queue - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: # Act @@ -1276,7 +1245,7 @@ class TestIntegration: # First call returns task 2, second call returns None mock_redis.rpop.side_effect = [wrapper.serialize(), None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1433,7 +1402,7 @@ class TestPerformanceScenarios: # Mock rpop to return tasks up to concurrency limit mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow: """Test early return when dataset does not exist.""" # Arrange session = MagicMock() - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = None - session.query.side_effect = lambda model: dataset_query + session = MagicMock() + session.scalar.return_value = None # dataset not found create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) @@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow: dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.first.return_value = document - session = MagicMock() - session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query + + def _scalar_se(stmt): + entity = stmt.column_descriptions[0].get("entity") + if entity is Dataset: + return dataset + return document + + session.scalar.side_effect = _scalar_se monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow: session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_document_query - session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs)) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock( + all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status]) + ) create_session_mock = MagicMock( side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)] @@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow: session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_query - session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow: """Test early return when dataset is missing after indexing.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.side_effect = [dataset, None] - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query - session3.query.side_effect = lambda model: dataset_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = None # dataset not found on second query monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow: _document_indexing("dataset-1", ["doc-1"]) # Assert - session3.query.assert_called() + session3.scalar.assert_called() def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None: """Test summary generation skipped when indexing_technique is not high_quality.""" @@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow: indexing_technique="economy", summary_index_setting={"enable": True}, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query - session3.query.side_effect = lambda model: dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow: """Test summary generation is skipped when indexing is paused.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]) monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) @@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow: """Test generic indexing runner exception is handled.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow: indexing_technique="high_quality", summary_index_setting={"enable": True}, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - phase1_query = MagicMock() - phase1_query.where.return_value = phase1_query - phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] - - summary_query = MagicMock() - summary_query.where.return_value = summary_query - summary_query.all.return_value = [_FalseyDocument("missing-doc")] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_query - session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index 0ed4ca05fa..626d1ee0a8 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch import pytest from libs.archive_storage import ArchiveStorageNotConfiguredError -from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( _delete_app_workflow_archive_logs, _delete_archived_workflow_run_files, @@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs: assert params == {"tenant_id": tenant_id, "app_id": app_id} assert name == "workflow archive log" - mock_query = MagicMock() - mock_delete_query = MagicMock() - mock_query.where.return_value = mock_delete_query - mock_db.session.query.return_value = mock_query + mock_session = MagicMock() - delete_func(mock_db.session, "log-1") + delete_func(mock_session, "log-1") - mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) - mock_query.where.assert_called_once() - mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + mock_session.execute.assert_called_once() class TestDeleteArchivedWorkflowRunFiles: