diff --git a/api/services/account_service.py b/api/services/account_service.py index 35e4a505af..d3893c1207 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -327,6 +327,17 @@ class AccountService: @staticmethod def delete_account(account: Account): """Delete account. This method only adds a task to the queue for deletion.""" + # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only) + from services.enterprise.account_deletion_sync import sync_account_deletion + + sync_success = sync_account_deletion(account_id=account.id, source="account_deleted") + if not sync_success: + logger.warning( + "Enterprise account deletion sync failed for account %s; proceeding with local deletion.", + account.id, + ) + + # Now proceed with async account deletion delete_account_task.delay(account.id) @staticmethod @@ -1230,6 +1241,19 @@ class TenantService: if dify_config.BILLING_ENABLED: BillingService.clean_billing_info_cache(tenant.id) + # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only) + from services.enterprise.account_deletion_sync import sync_workspace_member_removal + + sync_success = sync_workspace_member_removal( + workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed" + ) + if not sync_success: + logger.warning( + "Enterprise workspace member removal sync failed: workspace_id=%s, member_id=%s", + tenant.id, + account.id, + ) + @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): """Update member role""" diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py new file mode 100644 index 0000000000..c7ff42894d --- /dev/null +++ b/api/services/enterprise/account_deletion_sync.py @@ -0,0 +1,115 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin + +logger = logging.getLogger(__name__) + +ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue" +ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace" + + +def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Queue an account deletion sync task to Redis. + + Internal helper function. Do not call directly - use the public functions instead. + + Args: + workspace_id: The workspace/tenant ID to sync + member_id: The member/account ID that was removed + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "member_id": member_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + "type": ACCOUNT_DELETION_SYNC_TASK_TYPE, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s", + workspace_id, + member_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error( + "Failed to queue account deletion sync for workspace %s, member %s: %s", + workspace_id, + member_id, + str(e), + exc_info=True, + ) + # Don't raise - we don't want to fail member deletion if queueing fails + return False + + +def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Sync a single workspace member removal (enterprise only). + + Queues a task for the enterprise backend to reassign resources from the removed member. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + workspace_id: The workspace/tenant ID + member_id: The member/account ID that was removed + source: Source of the sync request (e.g., "workspace_member_removed") + + Returns: + bool: True if task was queued (or skipped in community), False if queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + +def sync_account_deletion(account_id: str, *, source: str) -> bool: + """ + Sync full account deletion across all workspaces (enterprise only). + + Fetches all workspace memberships for the account and queues a sync task for each. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + account_id: The account ID being deleted + source: Source of the sync request (e.g., "account_deleted") + + Returns: + bool: True if all tasks were queued (or skipped in community), False if any queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + # Fetch all workspaces the account belongs to + workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + + # Queue sync task for each workspace + success = True + for join in workspace_joins: + if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source): + success = False + + return success diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 4d4e77a802..a09a6e5c65 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -1016,7 +1016,7 @@ class TestAccountService: def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): """ - Test account deletion (should add task to queue). + Test account deletion (should add task to queue and sync to enterprise). """ fake = Faker() email = fake.email() @@ -1034,10 +1034,18 @@ class TestAccountService: password=password, ) - with patch("services.account_service.delete_account_task") as mock_delete_task: + with ( + patch("services.account_service.delete_account_task") as mock_delete_task, + patch("services.enterprise.account_deletion_sync.sync_account_deletion") as mock_sync, + ): + mock_sync.return_value = True + # Delete account AccountService.delete_account(account) + # Verify sync was called + mock_sync.assert_called_once_with(account_id=account.id, source="account_deleted") + # Verify task was added to queue mock_delete_task.delay.assert_called_once_with(account.id) @@ -1716,7 +1724,7 @@ class TestTenantService: def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): """ - Test successful member removal from tenant. + Test successful member removal from tenant (should sync to enterprise). """ fake = Faker() tenant_name = fake.company() @@ -1751,7 +1759,15 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, role="normal") # Remove member - TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + with patch("services.enterprise.account_deletion_sync.sync_workspace_member_removal") as mock_sync: + mock_sync.return_value = True + + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify sync was called + mock_sync.assert_called_once_with( + workspace_id=tenant.id, member_id=member_account.id, source="workspace_member_removed" + ) # Verify member was removed from extensions.ext_database import db diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..b66111902c --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,276 @@ +"""Unit tests for account deletion synchronization. + +This test module verifies the enterprise account deletion sync functionality, +including Redis queuing, error handling, and community vs enterprise behavior. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + """Unit tests for the _queue_task helper function.""" + + @pytest.fixture + def mock_redis_client(self): + """Mock redis_client for testing.""" + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + yield mock_redis + + @pytest.fixture + def mock_uuid(self): + """Mock UUID generation for predictable task IDs.""" + with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: + mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") + yield mock_uuid_gen + + def test_queue_task_success(self, mock_redis_client, mock_uuid): + """Test successful task queueing to Redis.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "test_source" + + # Act + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_redis_client.lpush.assert_called_once() + + # Verify the task payload structure + call_args = mock_redis_client.lpush.call_args[0] + assert call_args[0] == "enterprise:member:sync:queue" + + import json + + task_data = json.loads(call_args[1]) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == source + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, mock_redis_client, caplog): + """Test handling of Redis connection errors.""" + # Arrange + mock_redis_client.lpush.side_effect = RedisError("Connection failed") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, mock_redis_client, caplog): + """Test handling of JSON serialization errors.""" + # Arrange + mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") + + # Act + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + """Unit tests for sync_workspace_member_removal function.""" + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is True.""" + # Arrange + workspace_id = "ws-123" + member_id = "member-456" + source = "workspace_member_removed" + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) + + # Assert + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + """Test handling of queue task failures.""" + # Arrange + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") + + # Assert + assert result is False + + +class TestSyncAccountDeletion: + """Unit tests for sync_account_deletion function.""" + + @pytest.fixture + def mock_db_session(self): + """Mock database session for testing.""" + with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: + yield mock_session + + @pytest.fixture + def mock_queue_task(self): + """Mock _queue_task for testing.""" + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): + """Test sync when ENTERPRISE_ENABLED is False (community edition).""" + # Arrange + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_db_session.query.assert_not_called() + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with multiple workspace memberships.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is True + assert mock_queue_task.call_count == 3 + + # Verify each workspace was queued + mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") + mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") + + def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): + """Test sync for account with no workspace memberships.""" + # Arrange + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [] + mock_db_session.query.return_value = mock_query + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): + """Test sync when some tasks fail to queue.""" + # Arrange + account_id = "acc-123" + + # Mock workspace joins + mock_join1 = MagicMock() + mock_join1.tenant_id = "tenant-1" + mock_join2 = MagicMock() + mock_join2.tenant_id = "tenant-2" + mock_join3 = MagicMock() + mock_join3.tenant_id = "tenant-3" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] + mock_db_session.query.return_value = mock_query + + # Mock queue_task to fail for second workspace + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != "tenant-2" + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + # Assert + assert result is False # Should return False if any task fails + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): + """Test sync when all tasks fail to queue.""" + # Arrange + mock_join = MagicMock() + mock_join.tenant_id = "tenant-1" + + mock_query = MagicMock() + mock_query.filter_by.return_value.all.return_value = [mock_join] + mock_db_session.query.return_value = mock_query + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Act + result = sync_account_deletion(account_id="acc-123", source="account_deleted") + + # Assert + assert result is False + mock_queue_task.assert_called_once()