diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 817249845a..6240f2200f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): - def del_workflow_archive_log(workflow_archive_log_id: str): - db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + def del_workflow_archive_log(session, workflow_archive_log_id: str): + session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) @@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f46d1bf5db..d020233620 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -10,7 +10,10 @@ from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variables, + delete_draft_variables_batch, +) @pytest.fixture @@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.return_value = None with session_factory.create_session() as session: draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = session.query(WorkflowDraftVariableFile).count() - upload_files_before = session.query(UploadFile).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.side_effect = [Exception("Storage error"), None] deleted_count = delete_draft_variables_batch(app_id, batch_size=10) @@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration: if app2_obj: session.delete(app2_obj) session.commit() + + +class TestDeleteDraftVariablesSessionCommit: + """Test suite to verify session commit behavior in delete_draft_variables_batch.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with offload files for session commit tests.""" + from core.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now + + tenant, app = app_and_tenant + + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + session.flush() + + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() + + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() + + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + yield data + + with session_factory.create_session() as session: + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() + + @pytest.fixture + def setup_commit_test_data(self, app_and_tenant): + """Create test data for session commit tests.""" + tenant, app = app_and_tenant + variable_ids: list[str] = [] + + with session_factory.create_session() as session: + variables = [] + for i in range(10): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] + + yield { + "app": app, + "tenant": tenant, + "variable_ids": variable_ids, + } + + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) + ) + session.execute(cleanup_query) + session.commit() + + def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data): + """Test that session.begin() is used for automatic transaction management.""" + data = setup_commit_test_data + app_id = data["app"].id + + # Since session.begin() is used, the transaction is automatically committed + # when the with block exits successfully. We verify this by checking that + # data is actually persisted. + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + # Verify all data was deleted (proves transaction was committed) + with session_factory.create_session() as session: + remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + + assert deleted_count == 10 + assert remaining_count == 0 + + def test_data_persisted_after_batch_deletion(self, setup_commit_test_data): + """Test that data is actually persisted to database after batch deletion with commits.""" + data = setup_commit_test_data + app_id = data["app"].id + variable_ids = data["variable_ids"] + + # Verify initial state + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Perform deletion with small batch size to force multiple commits + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + assert deleted_count == 10 + + # Verify all data is deleted in a new session (proves commits worked) + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + # Verify specific IDs are deleted + with session_factory.create_session() as session: + remaining_vars = ( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + ) + assert remaining_vars == 0 + + def test_session_commit_with_empty_dataset(self, setup_commit_test_data): + """Test session behavior when deleting from an empty dataset.""" + nonexistent_app_id = str(uuid.uuid4()) + + # Should not raise any errors and should return 0 + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10) + assert deleted_count == 0 + + def test_session_commit_with_single_batch(self, setup_commit_test_data): + """Test that commit happens correctly when all data fits in a single batch.""" + data = setup_commit_test_data + app_id = data["app"].id + + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Delete all in a single batch + deleted_count = delete_draft_variables_batch(app_id, batch_size=100) + assert deleted_count == 10 + + # Verify data is persisted + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + def test_invalid_batch_size_raises_error(self, setup_commit_test_data): + """Test that invalid batch size raises ValueError.""" + data = setup_commit_test_data + app_id = data["app"].id + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=-1) + + @patch("extensions.ext_storage.storage") + def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data): + """Test that session commits correctly when cleaning up offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + mock_storage.delete.return_value = None + + # Verify initial state + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_before == 3 + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete variables with offload data + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + assert deleted_count == 3 + + # Verify all data is persisted (deleted) in new session + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_after == 0 + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage cleanup was called + assert mock_storage.delete.call_count == 2 diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index a14bbb01d0..2b11e42cd5 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs: mock_query.where.return_value = mock_delete_query mock_db.session.query.return_value = mock_query - delete_func("log-1") + delete_func(mock_db.session, "log-1") mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) mock_query.where.assert_called_once()