mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor(tasks): migrate document_indexing_task and remove_app_and_related_data_task to SQLAlchemy 2.0 select() API (#34968)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
0841b4c663
commit
f67297688f
@ -5,6 +5,7 @@ from typing import Any, Protocol
|
||||
|
||||
import click
|
||||
from celery import current_app, shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
Usage: _document_indexing(dataset_id, document_ids)
|
||||
"""
|
||||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
|
||||
return
|
||||
@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.ERROR
|
||||
@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
# Phase 1: Update status to parsing (short transaction)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
documents = (
|
||||
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
|
||||
documents: list[Document] = list(
|
||||
session.scalars(
|
||||
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
).all()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
# Trigger summary index generation for completed documents if enabled
|
||||
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
||||
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
logger.warning("Dataset %s not found after indexing", dataset_id)
|
||||
return
|
||||
@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
session.expire_all()
|
||||
# Check each document's indexing status and trigger summary generation if completed
|
||||
|
||||
documents = (
|
||||
session.query(Document)
|
||||
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
.all()
|
||||
documents = list(
|
||||
session.scalars(
|
||||
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
).all()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, cast
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_model_configs(tenant_id: str, app_id: str):
|
||||
def del_model_config(session, model_config_id: str):
|
||||
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(AppModelConfig)
|
||||
.where(AppModelConfig.id == model_config_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_model_configs where app_id=:app_id limit 1000""",
|
||||
@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_site(tenant_id: str, app_id: str):
|
||||
def del_site(session, site_id: str):
|
||||
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
|
||||
session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False))
|
||||
|
||||
_delete_records(
|
||||
"""select id from sites where app_id=:app_id limit 1000""",
|
||||
@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||
def del_mcp_server(session, mcp_server_id: str):
|
||||
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
|
||||
@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||
def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||
def del_api_token(session, api_token_id: str):
|
||||
# Fetch token details for cache invalidation
|
||||
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
|
||||
token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1))
|
||||
if token_obj:
|
||||
# Invalidate cache before deletion
|
||||
ApiTokenCache.delete(token_obj.token, token_obj.type)
|
||||
|
||||
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from api_tokens where app_id=:app_id limit 1000""",
|
||||
@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_installed_apps(tenant_id: str, app_id: str):
|
||||
def del_installed_app(session, installed_app_id: str):
|
||||
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_recommended_apps(tenant_id: str, app_id: str):
|
||||
def del_recommended_app(session, recommended_app_id: str):
|
||||
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(RecommendedApp)
|
||||
.where(RecommendedApp.id == recommended_app_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from recommended_apps where app_id=:app_id limit 1000""",
|
||||
@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||
def del_annotation_hit_history(session, annotation_hit_history_id: str):
|
||||
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(AppAnnotationHitHistory)
|
||||
.where(AppAnnotationHitHistory.id == annotation_hit_history_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||
)
|
||||
|
||||
def del_annotation_setting(session, annotation_setting_id: str):
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(AppAnnotationSetting)
|
||||
.where(AppAnnotationSetting.id == annotation_setting_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
||||
def del_dataset_join(session, dataset_join_id: str):
|
||||
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(AppDatasetJoin)
|
||||
.where(AppDatasetJoin.id == dataset_join_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
|
||||
@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_workflows(tenant_id: str, app_id: str):
|
||||
def del_workflow(session, workflow_id: str):
|
||||
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
|
||||
session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False))
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_app_log(session, workflow_app_log_id: str):
|
||||
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(WorkflowAppLog)
|
||||
.where(WorkflowAppLog.id == workflow_app_log_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_archive_log(session, workflow_archive_log_id: str):
|
||||
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(WorkflowArchiveLog)
|
||||
.where(WorkflowArchiveLog.id == workflow_archive_log_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_conversations(tenant_id: str, app_id: str):
|
||||
def del_conversation(session, conversation_id: str):
|
||||
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(PinnedConversation)
|
||||
.where(PinnedConversation.conversation_id == conversation_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(
|
||||
delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from conversations where app_id=:app_id limit 1000""",
|
||||
@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str):
|
||||
|
||||
def _delete_app_messages(tenant_id: str, app_id: str):
|
||||
def del_message(session, message_id: str):
|
||||
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == message_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(MessageAnnotation)
|
||||
.where(MessageAnnotation.message_id == message_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
|
||||
session.query(Message).where(Message.id == message_id).delete()
|
||||
session.execute(
|
||||
delete(MessageChain)
|
||||
.where(MessageChain.message_id == message_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(
|
||||
delete(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == message_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(
|
||||
delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(
|
||||
delete(SavedMessage)
|
||||
.where(SavedMessage.message_id == message_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False))
|
||||
|
||||
_delete_records(
|
||||
"""select id from messages where app_id=:app_id limit 1000""",
|
||||
@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
||||
def del_tool_provider(session, tool_provider_id: str):
|
||||
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.id == tool_provider_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
||||
def del_tag_binding(session, tag_binding_id: str):
|
||||
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
|
||||
@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_end_users(tenant_id: str, app_id: str):
|
||||
def del_end_user(session, end_user_id: str):
|
||||
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
|
||||
session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False))
|
||||
|
||||
_delete_records(
|
||||
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
||||
def del_trace_app_config(session, trace_app_config_id: str):
|
||||
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(TraceAppConfig)
|
||||
.where(TraceAppConfig.id == trace_app_config_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from trace_app_config where app_id=:app_id limit 1000""",
|
||||
@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
|
||||
|
||||
def _delete_app_triggers(tenant_id: str, app_id: str):
|
||||
def del_app_trigger(session, trigger_id: str):
|
||||
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
||||
def del_plugin_trigger(session, trigger_id: str):
|
||||
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(WorkflowPluginTrigger)
|
||||
.where(WorkflowPluginTrigger.id == trigger_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
||||
def del_webhook_trigger(session, trigger_id: str):
|
||||
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
|
||||
synchronize_session=False
|
||||
session.execute(
|
||||
delete(WorkflowWebhookTrigger)
|
||||
.where(WorkflowWebhookTrigger.id == trigger_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
||||
def del_schedule_plan(session, plan_id: str):
|
||||
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(WorkflowSchedulePlan)
|
||||
.where(WorkflowSchedulePlan.id == plan_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
|
||||
|
||||
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
|
||||
def del_trigger_log(session, log_id: str):
|
||||
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
|
||||
session.execute(
|
||||
delete(WorkflowTriggerLog)
|
||||
.where(WorkflowTriggerLog.id == log_id)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
|
||||
@ -82,8 +82,8 @@ def mock_db_session():
|
||||
"""Mock session_factory.create_session() to return a session whose queries use shared test data.
|
||||
|
||||
Tests set session._shared_data = {"dataset": <Dataset>, "documents": [<Document>, ...]}
|
||||
This fixture makes session.query(Dataset).first() return the shared dataset,
|
||||
and session.query(Document).all()/first() return from the shared documents.
|
||||
This fixture makes session.scalar(select(Dataset)...) return the shared dataset,
|
||||
and session.scalars(select(Document)...).all() return the shared documents.
|
||||
"""
|
||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
@ -92,93 +92,68 @@ def mock_db_session():
|
||||
# Keep a pointer so repeated Document.first() calls iterate across provided docs
|
||||
session._doc_first_idx = 0
|
||||
|
||||
def _query_side_effect(model):
|
||||
q = MagicMock()
|
||||
def _get_entity(stmt) -> type | None:
|
||||
"""Extract the mapped entity class from a SQLAlchemy select statement."""
|
||||
try:
|
||||
descs = stmt.column_descriptions
|
||||
if descs:
|
||||
return descs[0].get("entity")
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
# Capture filters passed via where(...) so first()/all() can honor them.
|
||||
q._filters = {}
|
||||
def _extract_id_from_where(stmt) -> str | None:
|
||||
"""Return the value bound to the 'id' column in the WHERE clause, if present."""
|
||||
try:
|
||||
where = stmt.whereclause
|
||||
if where is None:
|
||||
return None
|
||||
# Both single-clause and AND-clause-list cases
|
||||
clauses = list(getattr(where, "clauses", [where]))
|
||||
for clause in clauses:
|
||||
left = getattr(clause, "left", None)
|
||||
right = getattr(clause, "right", None)
|
||||
if left is not None and right is not None:
|
||||
if getattr(left, "key", None) == "id":
|
||||
return getattr(right, "value", None)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _extract_filters(*conds, **kw):
|
||||
# Support both SQLAlchemy expressions (BinaryExpression) and kwargs
|
||||
# We only need the simple fields used by production code: id, dataset_id, and id.in_(...)
|
||||
for cond in conds:
|
||||
left = getattr(cond, "left", None)
|
||||
right = getattr(cond, "right", None)
|
||||
key = None
|
||||
if left is not None:
|
||||
key = getattr(left, "key", None) or getattr(left, "name", None)
|
||||
if not key:
|
||||
continue
|
||||
# Right side might be a BindParameter with .value, or a raw value/sequence
|
||||
val = getattr(right, "value", right)
|
||||
q._filters[key] = val
|
||||
# Also accept kwargs (e.g., where(id=...)) just in case
|
||||
for k, v in kw.items():
|
||||
q._filters[k] = v
|
||||
|
||||
def _where_side_effect(*conds, **kw):
|
||||
_extract_filters(*conds, **kw)
|
||||
return q
|
||||
|
||||
q.where.side_effect = _where_side_effect
|
||||
|
||||
# Dataset queries
|
||||
if model.__name__ == "Dataset":
|
||||
|
||||
def _dataset_first():
|
||||
ds = session._shared_data.get("dataset")
|
||||
if not ds:
|
||||
return None
|
||||
if "id" in q._filters:
|
||||
val = q._filters["id"]
|
||||
if isinstance(val, (list, tuple, set)):
|
||||
return ds if ds.id in val else None
|
||||
return ds if ds.id == val else None
|
||||
return ds
|
||||
|
||||
def _dataset_all():
|
||||
ds = session._shared_data.get("dataset")
|
||||
if not ds:
|
||||
return []
|
||||
first = _dataset_first()
|
||||
return [first] if first else []
|
||||
|
||||
q.first.side_effect = _dataset_first
|
||||
q.all.side_effect = _dataset_all
|
||||
return q
|
||||
|
||||
# Document queries
|
||||
if model.__name__ == "Document":
|
||||
|
||||
def _apply_doc_filters(docs):
|
||||
result = list(docs)
|
||||
for key in ("id", "dataset_id"):
|
||||
if key in q._filters:
|
||||
val = q._filters[key]
|
||||
if isinstance(val, (list, tuple, set)):
|
||||
result = [d for d in result if getattr(d, key, None) in val]
|
||||
else:
|
||||
result = [d for d in result if getattr(d, key, None) == val]
|
||||
return result
|
||||
|
||||
def _docs_all():
|
||||
def _scalar_side_effect(stmt):
|
||||
entity = _get_entity(stmt)
|
||||
if entity is not None:
|
||||
if entity.__name__ == "Dataset":
|
||||
return session._shared_data.get("dataset")
|
||||
elif entity.__name__ == "Document":
|
||||
docs = session._shared_data.get("documents", [])
|
||||
return _apply_doc_filters(docs)
|
||||
if not docs:
|
||||
return None
|
||||
# When the WHERE clause filters by id, return the matching document
|
||||
queried_id = _extract_id_from_where(stmt)
|
||||
if queried_id:
|
||||
doc_map = {d.id: d for d in docs}
|
||||
return doc_map.get(queried_id, docs[0])
|
||||
return docs[0]
|
||||
return None
|
||||
|
||||
def _docs_first():
|
||||
docs = _docs_all()
|
||||
return docs[0] if docs else None
|
||||
def _scalars_side_effect(stmt):
|
||||
entity = _get_entity(stmt)
|
||||
result = MagicMock()
|
||||
if entity is not None:
|
||||
if entity.__name__ == "Document":
|
||||
result.all.return_value = list(session._shared_data.get("documents", []))
|
||||
elif entity.__name__ == "Dataset":
|
||||
ds = session._shared_data.get("dataset")
|
||||
result.all.return_value = [ds] if ds else []
|
||||
else:
|
||||
result.all.return_value = []
|
||||
else:
|
||||
result.all.return_value = []
|
||||
return result
|
||||
|
||||
q.all.side_effect = _docs_all
|
||||
q.first.side_effect = _docs_first
|
||||
return q
|
||||
|
||||
# Default fallback
|
||||
q.first.return_value = None
|
||||
q.all.return_value = []
|
||||
return q
|
||||
|
||||
session.query.side_effect = _query_side_effect
|
||||
session.scalar.side_effect = _scalar_side_effect
|
||||
session.scalars.side_effect = _scalars_side_effect
|
||||
|
||||
# Implement session.begin() context manager that commits on exit
|
||||
session.commit = MagicMock()
|
||||
@ -638,8 +613,6 @@ class TestProgressTracking:
|
||||
wrapper = TaskWrapper(data=next_task_data)
|
||||
mock_redis.rpop.return_value = wrapper.serialize()
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
|
||||
@ -662,7 +635,6 @@ class TestProgressTracking:
|
||||
"""
|
||||
# Arrange
|
||||
mock_redis.rpop.return_value = None # No more tasks
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@ -780,8 +752,7 @@ class TestErrorHandling:
|
||||
|
||||
If the dataset doesn't exist, the task should exit gracefully.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
# Arrange - dataset is not in _shared_data (None by default), so scalar() returns None
|
||||
|
||||
# Act
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
@ -806,8 +777,6 @@ class TestErrorHandling:
|
||||
# Set up rpop to return task once for concurrency check
|
||||
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Make _document_indexing raise an error
|
||||
with patch("tasks.document_indexing_task._document_indexing") as mock_indexing:
|
||||
mock_indexing.side_effect = Exception("Processing failed")
|
||||
@ -844,7 +813,7 @@ class TestErrorHandling:
|
||||
# Mock rpop to return tasks one by one
|
||||
mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None]
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
|
||||
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
|
||||
@ -977,7 +946,7 @@ class TestAdvancedScenarios:
|
||||
|
||||
# Mock rpop to return tasks up to concurrency limit
|
||||
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
|
||||
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
|
||||
@ -1070,7 +1039,7 @@ class TestAdvancedScenarios:
|
||||
|
||||
# Mock rpop to return tasks in FIFO order
|
||||
mock_redis.rpop.side_effect = tasks + [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3):
|
||||
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
|
||||
@ -1108,7 +1077,7 @@ class TestAdvancedScenarios:
|
||||
"""
|
||||
# Arrange
|
||||
mock_redis.rpop.return_value = None # Empty queue
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
|
||||
# Act
|
||||
@ -1276,7 +1245,7 @@ class TestIntegration:
|
||||
# First call returns task 2, second call returns None
|
||||
mock_redis.rpop.side_effect = [wrapper.serialize(), None]
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@ -1433,7 +1402,7 @@ class TestPerformanceScenarios:
|
||||
|
||||
# Mock rpop to return tasks up to concurrency limit
|
||||
mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
|
||||
with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit):
|
||||
with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task:
|
||||
@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Test early return when dataset does not exist."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = None
|
||||
session.query.side_effect = lambda model: dataset_query
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None # dataset not found
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None)
|
||||
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query
|
||||
|
||||
def _scalar_se(stmt):
|
||||
entity = stmt.column_descriptions[0].get("entity")
|
||||
if entity is Dataset:
|
||||
return dataset
|
||||
return document
|
||||
|
||||
session.scalar.side_effect = _scalar_se
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_document_query
|
||||
session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs))
|
||||
session3.scalar.return_value = dataset
|
||||
session3.scalars.return_value = MagicMock(
|
||||
all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status])
|
||||
)
|
||||
|
||||
create_session_mock = MagicMock(
|
||||
side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)]
|
||||
@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
session3.scalar.return_value = dataset
|
||||
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible]))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Test early return when dataset is missing after indexing."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.side_effect = [dataset, None]
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
session3.scalar.return_value = None # dataset not found on second query
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
_document_indexing("dataset-1", ["doc-1"])
|
||||
|
||||
# Assert
|
||||
session3.query.assert_called()
|
||||
session3.scalar.assert_called()
|
||||
|
||||
def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test summary generation skipped when indexing_technique is not high_quality."""
|
||||
@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
indexing_technique="economy",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session3.query.side_effect = lambda model: dataset_query
|
||||
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
session3.scalar.return_value = dataset
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Test summary generation is skipped when indexing is paused."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
|
||||
create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)])
|
||||
monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock)
|
||||
@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
"""Test generic indexing runner exception is handled."""
|
||||
# Arrange
|
||||
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
document_query = MagicMock()
|
||||
document_query.where.return_value = document_query
|
||||
document_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: document_query
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow:
|
||||
indexing_technique="high_quality",
|
||||
summary_index_setting={"enable": True},
|
||||
)
|
||||
dataset_query = MagicMock()
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.first.return_value = dataset
|
||||
|
||||
phase1_query = MagicMock()
|
||||
phase1_query.where.return_value = phase1_query
|
||||
phase1_query.all.return_value = [SimpleNamespace(id="doc-1")]
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value = summary_query
|
||||
summary_query.all.return_value = [_FalseyDocument("missing-doc")]
|
||||
|
||||
session1 = MagicMock()
|
||||
session2 = MagicMock()
|
||||
session2.begin.return_value = nullcontext()
|
||||
session3 = MagicMock()
|
||||
session1.query.side_effect = lambda model: dataset_query
|
||||
session2.query.side_effect = lambda model: phase1_query
|
||||
session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query
|
||||
|
||||
session1.scalar.return_value = dataset
|
||||
session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")]))
|
||||
session3.scalar.return_value = dataset
|
||||
session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")]))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.document_indexing_task.session_factory.create_session",
|
||||
|
||||
@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch
|
||||
import pytest
|
||||
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError
|
||||
from models.workflow import WorkflowArchiveLog
|
||||
from tasks.remove_app_and_related_data_task import (
|
||||
_delete_app_workflow_archive_logs,
|
||||
_delete_archived_workflow_run_files,
|
||||
@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs:
|
||||
assert params == {"tenant_id": tenant_id, "app_id": app_id}
|
||||
assert name == "workflow archive log"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_delete_query = MagicMock()
|
||||
mock_query.where.return_value = mock_delete_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_session = MagicMock()
|
||||
|
||||
delete_func(mock_db.session, "log-1")
|
||||
delete_func(mock_session, "log-1")
|
||||
|
||||
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_delete_query.delete.assert_called_once_with(synchronize_session=False)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestDeleteArchivedWorkflowRunFiles:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user