From a808389122f616ea3b89a7070dfa450a267e83b7 Mon Sep 17 00:00:00 2001 From: mahammadasim <135003320+mahammadasim@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:25:18 +0530 Subject: [PATCH] test: add new unit tests for message service utilities, get message, feedback, and retention services. (#33169) --- .../test_messages_clean_service.py | 309 +++++ ...ear_free_plan_expired_workflow_run_logs.py | 499 ++++++++ .../test_delete_archived_workflow_run.py | 216 ++++ .../test_restore_archived_workflow_run.py | 1020 ++++++++++++++++ .../test_api_based_extension_service.py | 421 +++++++ .../services/test_app_dsl_service.py | 913 ++++++++++++++ .../services/test_app_generate_service.py | 815 +++++++++++-- ...est_clear_free_plan_tenant_expired_logs.py | 455 ++++++- .../services/test_conversation_service.py | 1066 ++++++++++++++++- .../services/test_end_user_service.py | 748 +++++++++++- .../unit_tests/services/test_file_service.py | 420 +++++++ .../test_human_input_delivery_test_service.py | 342 ++++-- .../services/test_human_input_service.py | 177 ++- .../services/test_message_service.py | 426 ++++++- 14 files changed, 7598 insertions(+), 229 deletions(-) create mode 100644 api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py create mode 100644 api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py create mode 100644 api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py create mode 100644 api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py create mode 100644 api/tests/unit_tests/services/test_api_based_extension_service.py create mode 100644 api/tests/unit_tests/services/test_app_dsl_service.py create mode 100644 api/tests/unit_tests/services/test_file_service.py diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py new file mode 100644 index 0000000000..a34defeba9 --- /dev/null +++ b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py @@ -0,0 +1,309 @@ +import datetime +import os +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanService: + @pytest.fixture(autouse=True) + def mock_db_engine(self): + with patch("services.retention.conversation.messages_clean_service.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db.engine + + @pytest.fixture + def mock_db_session(self, mock_db_engine): + with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + yield mock_session + + @pytest.fixture + def mock_policy(self): + policy = MagicMock(spec=BillingDisabledPolicy) + return policy + + def test_run_calls_clean_messages(self, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() + + def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): + # Arrange + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock( + rowcount=1 + ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete messages + MagicMock(all=lambda: []), # next batch empty + ] + + # Reset side_effect to be more robust + # The service calls session.execute for: + # 1. Fetch messages + # 2. Fetch apps + # 3. Batch delete relations (8 calls if IDs exist) + # 4. Delete messages + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps + ] + # 8 deletes for relations + mock_returns.extend([MagicMock() for _ in range(8)]) + # 1 delete for messages + mock_returns.append(MagicMock(rowcount=1)) + # Final fetch messages (empty) + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + # Act + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 1 + assert stats["batches"] == 2 + + def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + start_from=start_from, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: []), # No messages + ] + + stats = service.run() + assert stats["total_messages"] == 0 + + def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): + # Test pagination with cursor + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=1, + ) + + msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) + msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) + + mock_returns = [] + # Batch 1 + mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + + # Batch 2 + mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + + # Batch 3 + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified + + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + assert stats["batches"] == 3 + assert stats["total_messages"] == 2 + + def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + dry_run=True, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: + mock_sample.return_value = ["msg1"] + stats = service.run() + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 0 # Dry run + mock_sample.assert_called() + + def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # apps NOT found + MagicMock(all=lambda: []), # next batch empty + ] + + stats = service.run() + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 0 + + def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # next batch empty + ] + + # We need to successfully execute line 228 and 229, then return empty at 251. + # line 228: raw_messages = list(session.execute(msg_stmt).all()) + # line 251: app_ids = list({msg.app_id for msg in messages}) + + calls = [] + + def list_side_effect(arg): + calls.append(arg) + if len(calls) == 2: # This is the second call to list() in the loop + return [] + return list(arg) + + with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): + stats = service.run() + assert stats["batches"] == 2 + assert stats["total_messages"] == 1 + + def test_from_time_range_validation(self, mock_policy): + now = datetime.datetime.now() + # Test start_from >= end_before + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(mock_policy, now, now) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self, mock_policy): + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + # Mock logger to avoid actual logging if needed, though it's fine + service = MessagesCleanService.from_time_range(mock_policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self, mock_policy): + # Test days < 0 + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(mock_policy, days=-1) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) + + def test_from_days_success(self, mock_policy): + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now = datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(mock_policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = [] # Policy says NO + + stats = service.run() + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_batch_delete_message_relations_empty(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, []) + mock_db_session.execute.assert_not_called() + + def test_batch_delete_message_relations_with_ids(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) + assert mock_db_session.execute.call_count == 8 # 8 tables to clean up + + @patch.dict(os.environ, {"SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL": "500"}) + def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + ] + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + mock_returns.append(MagicMock(all=lambda: [])) # next batch empty + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: + with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: + mock_uniform.return_value = 300.0 + service.run() + mock_uniform.assert_called_with(0, 500) + mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..0013cde79e --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,499 @@ +""" +Unit tests for WorkflowRunCleanup service. +""" + +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +def make_run(tenant_id: str = "t1", run_id: str = "r1", created_at: datetime.datetime | None = None): + run = MagicMock() + run.tenant_id = tenant_id + run.id = run_id + run.created_at = created_at or datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC) + return run + + +@pytest.fixture +def mock_repo(): + return MagicMock() + + +@pytest.fixture +def cleanup(mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + yield WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + +# --------------------------------------------------------------------------- +# Constructor validation +# --------------------------------------------------------------------------- + + +class TestWorkflowRunCleanupInit: + def test_only_start_from_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_only_end_before_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_end_before_not_greater_than_start_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="end_before must be greater than start_from"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 6, 1), + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_equal_start_end_raises(self, mock_repo): + dt = datetime.datetime(2024, 1, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=10, start_from=dt, end_before=dt, workflow_run_repo=mock_repo) + + def test_zero_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="batch_size must be greater than 0"): + WorkflowRunCleanup(days=30, batch_size=0, workflow_run_repo=mock_repo) + + def test_negative_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=-1, workflow_run_repo=mock_repo) + + def test_valid_window_init(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 7 + cfg.BILLING_ENABLED = False + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 6, 1) + c = WorkflowRunCleanup(days=30, batch_size=5, start_from=start, end_before=end, workflow_run_repo=mock_repo) + assert c.window_start == start + assert c.window_end == end + + +# --------------------------------------------------------------------------- +# _empty_related_counts / _format_related_counts +# --------------------------------------------------------------------------- + + +class TestStaticHelpers: + def test_empty_related_counts(self): + counts = WorkflowRunCleanup._empty_related_counts() + assert counts == { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def test_format_related_counts(self): + counts = { + "node_executions": 1, + "offloads": 2, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + } + result = WorkflowRunCleanup._format_related_counts(counts) + assert "node_executions 1" in result + assert "offloads 2" in result + assert "trigger_logs 4" in result + + +# --------------------------------------------------------------------------- +# _expiration_datetime +# --------------------------------------------------------------------------- + + +class TestExpirationDatetime: + def test_negative_returns_none(self, cleanup): + assert cleanup._expiration_datetime("t1", -1) is None + + def test_valid_timestamp(self, cleanup): + ts = int(datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC).timestamp()) + result = cleanup._expiration_datetime("t1", ts) + assert result is not None + assert result.year == 2025 + + def test_overflow_returns_none(self, cleanup): + result = cleanup._expiration_datetime("t1", 2**62) + assert result is None + + +# --------------------------------------------------------------------------- +# _is_within_grace_period +# --------------------------------------------------------------------------- + + +class TestIsWithinGracePeriod: + def test_zero_grace_period_returns_false(self, cleanup): + cleanup.free_plan_grace_period_days = 0 + assert cleanup._is_within_grace_period("t1", {"expiration_date": 9999999999}) is False + + def test_within_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + # expired just 1 day ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=1) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is True + + def test_outside_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 5 + # expired 100 days ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=100) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is False + + def test_missing_expiration_date_returns_false(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + assert cleanup._is_within_grace_period("t1", {"expiration_date": -1}) is False + + +# --------------------------------------------------------------------------- +# _get_cleanup_whitelist +# --------------------------------------------------------------------------- + + +class TestGetCleanupWhitelist: + def test_billing_disabled_returns_empty(self, cleanup): + cleanup._cleanup_whitelist = None + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + result = cleanup._get_cleanup_whitelist() + assert result == set() + + def test_billing_enabled_fetches_whitelist(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.return_value = ["t1", "t2"] + result = c._get_cleanup_whitelist() + assert result == {"t1", "t2"} + + def test_cached_whitelist_returned(self, cleanup): + cleanup._cleanup_whitelist = {"cached"} + result = cleanup._get_cleanup_whitelist() + assert result == {"cached"} + + def test_billing_service_error_returns_empty(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.side_effect = Exception("error") + result = c._get_cleanup_whitelist() + assert result == set() + + +# --------------------------------------------------------------------------- +# _filter_free_tenants +# --------------------------------------------------------------------------- + + +class TestFilterFreeTenants: + def test_billing_disabled_all_tenants_free(self, cleanup): + result = cleanup._filter_free_tenants(["t1", "t2"]) + assert result == {"t1", "t2"} + + def test_empty_tenants_returns_empty(self, cleanup): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = True + result = cleanup._filter_free_tenants([]) + assert result == set() + + def test_whitelisted_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = {"t1"} + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + from enums.cloud_plan import CloudPlan + + bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "t2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1", "t2"]) + assert "t1" not in result + assert "t2" in result + + def test_paid_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": "professional", "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_missing_billing_info_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = {} + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_billing_bulk_error_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.side_effect = Exception("fail") + result = c._filter_free_tenants(["t1"]) + assert result == set() + + +# --------------------------------------------------------------------------- +# run() — delete mode +# --------------------------------------------------------------------------- + + +class TestRunDeleteMode: + def _make_cleanup(self, mock_repo, billing_enabled=False): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = billing_enabled + return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + def test_no_rows_stops_immediately(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_all_paid_skips_delete(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_cleanup(mock_repo) + # billing disabled -> all free; but let's override _filter_free_tenants to return empty + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_runs_deleted_successfully(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.return_value = { + "runs": 1, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.time.sleep"): + c.run() + mock_repo.delete_runs_with_related.assert_called_once() + + def test_delete_exception_reraises(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.side_effect = RuntimeError("db error") + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with pytest.raises(RuntimeError): + c.run() + + def test_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + ) + c.run() + + +# --------------------------------------------------------------------------- +# run() — dry run mode +# --------------------------------------------------------------------------- + + +class TestRunDryRunMode: + def _make_dry_cleanup(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo, dry_run=True) + + def test_dry_run_no_delete_called(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.count_runs_with_related.return_value = { + "node_executions": 2, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 1, + "pauses": 0, + "pause_reasons": 0, + } + c = self._make_dry_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + mock_repo.count_runs_with_related.assert_called_once() + + def test_dry_run_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + dry_run=True, + ) + c.run() + + def test_dry_run_all_paid_skips_count(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_dry_cleanup(mock_repo) + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.count_runs_with_related.assert_not_called() + + +# --------------------------------------------------------------------------- +# _delete_trigger_logs / _count_trigger_logs +# --------------------------------------------------------------------------- + + +class TestTriggerLogMethods: + def test_delete_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + instance.delete_by_run_ids.return_value = 5 + result = cleanup._delete_trigger_logs(session, ["r1", "r2"]) + assert result == 5 + + def test_count_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + instance.count_by_run_ids.return_value = 3 + result = cleanup._count_trigger_logs(session, ["r1"]) + assert result == 3 + + +# --------------------------------------------------------------------------- +# _count_node_executions / _delete_node_executions +# --------------------------------------------------------------------------- + + +class TestNodeExecutionMethods: + def test_count_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.count_by_runs.return_value = (10, 2) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._count_node_executions(session, runs) + assert result == (10, 2) + + def test_delete_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.delete_by_runs.return_value = (5, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._delete_node_executions(session, runs) + assert result == (5, 1) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..9fe153c153 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py @@ -0,0 +1,216 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from models.workflow import WorkflowRun +from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult + + +class TestArchivedWorkflowRunDeletion: + @pytest.fixture + def mock_db(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture + def mock_sessionmaker(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + mock_session = MagicMock(spec=Session) + mock_sm.return_value.return_value.__enter__.return_value = mock_session + yield mock_sm, mock_session + + @pytest.fixture + def mock_workflow_run_repo(self): + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + yield mock_repo + + def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + tenant_id = "tenant-456" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_run.tenant_id = tenant_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [run_id] + + with patch.object(deletion, "_delete_run") as mock_delete_run: + expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) + mock_delete_run.return_value = expected_result + + result = deletion.delete_by_run_id(run_id) + + assert result == expected_result + mock_session.get.assert_called_once_with(WorkflowRun, run_id) + mock_repo.get_archived_run_ids.assert_called_once() + mock_delete_run.assert_called_once_with(mock_run) + + def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + mock_session.get.return_value = None + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo"): + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "not found" in result.error + assert result.run_id == run_id + + def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [] + + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "is not archived" in result.error + + def test_delete_batch(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + deletion = ArchivedWorkflowRunDeletion() + + mock_run1 = MagicMock(spec=WorkflowRun) + mock_run1.id = "run-1" + mock_run2 = MagicMock(spec=WorkflowRun) + mock_run2.id = "run-2" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] + + with patch.object(deletion, "_delete_run") as mock_delete_run: + mock_delete_run.side_effect = [ + DeleteResult(run_id="run-1", tenant_id="t1", success=True), + DeleteResult(run_id="run-2", tenant_id="t1", success=True), + ] + + results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=True) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.run_id == "run-123" + + def test_delete_run_success(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.deleted_counts == {"workflow_runs": 1} + + def test_delete_run_exception(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deletion._delete_run(mock_run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_trigger_logs(self): + mock_session = MagicMock(spec=Session) + run_ids = ["run-1", "run-2"] + + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + mock_repo_cls.return_value = mock_repo + mock_repo.delete_by_run_ids.return_value = 5 + + count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) + + assert count == 5 + mock_repo_cls.assert_called_once_with(mock_session) + mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) + + def test_delete_node_executions(self): + mock_session = MagicMock(spec=Session) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-1" + runs = [mock_run] + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.delete_by_runs.return_value = (1, 2) + + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) + + assert result == (1, 2) + mock_create_repo.assert_called_once() + mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) + + def test_get_workflow_run_repo(self, mock_db): + deletion = ArchivedWorkflowRunDeletion() + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # First call + repo1 = deletion._get_workflow_run_repo() + assert repo1 == mock_repo + assert deletion.workflow_run_repo == mock_repo + + # Second call (should return cached) + repo2 = deletion._get_workflow_run_repo() + assert repo2 == mock_repo + mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..6097bcbd61 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -0,0 +1,1020 @@ +""" +Comprehensive unit tests for WorkflowRunRestore service. + +This file provides complete test coverage for all WorkflowRunRestore methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. +""" + +import io +import json +import zipfile +from datetime import datetime +from unittest.mock import Mock, create_autospec, patch + +import pytest + +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from services.retention.workflow_run.restore_archived_workflow_run import ( + SCHEMA_MAPPERS, + TABLE_MODELS, + RestoreResult, + WorkflowRunRestore, +) + + +class WorkflowRunRestoreTestDataFactory: + """ + Factory for creating test data and mock objects. + + Provides reusable methods to create consistent mock objects for testing + workflow run restore operations. + """ + + @staticmethod + def create_workflow_run_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowRun object. + + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowRun object with specified attributes + """ + run = create_autospec(WorkflowRun, instance=True) + run.id = run_id + run.tenant_id = tenant_id + run.app_id = app_id + run.created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(run, key, value) + return run + + @staticmethod + def create_workflow_archive_log_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowArchiveLog object. + + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowArchiveLog object with specified attributes + """ + archive_log = create_autospec(WorkflowArchiveLog, instance=True) + archive_log.workflow_run_id = run_id + archive_log.tenant_id = tenant_id + archive_log.app_id = app_id + archive_log.run_created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(archive_log, key, value) + return archive_log + + @staticmethod + def create_archive_zip_mock( + manifest: dict | None = None, + tables_data: dict[str, list[dict]] | None = None, + ) -> bytes: + """ + Create a mock archive zip file in memory. + + Args: + manifest: Archive manifest data + tables_data: Dictionary mapping table names to list of records + + Returns: + Bytes representing the zip file + """ + if manifest is None: + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + "workflow_app_logs": {"row_count": 2}, + }, + } + + if tables_data is None: + tables_data = { + "workflow_runs": [{"id": "run-123", "tenant_id": "tenant-123"}], + "workflow_app_logs": [ + {"id": "log-1", "workflow_run_id": "run-123"}, + {"id": "log-2", "workflow_run_id": "run-123"}, + ], + } + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest)) + for table_name, records in tables_data.items(): + jsonl_data = "\n".join(json.dumps(record) for record in records) + zip_file.writestr(f"{table_name}.jsonl", jsonl_data) + + zip_buffer.seek(0) + return zip_buffer.getvalue() + + +# --------------------------------------------------------------------------- +# Test WorkflowRunRestore Initialization +# --------------------------------------------------------------------------- + + +class TestWorkflowRunRestoreInit: + """Tests for WorkflowRunRestore.__init__ method.""" + + def test_default_initialization(self): + """Service should initialize with default values.""" + restore = WorkflowRunRestore() + assert restore.dry_run is False + assert restore.workers == 1 + assert restore.workflow_run_repo is None + + def test_dry_run_initialization(self): + """Service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + assert restore.dry_run is True + assert restore.workers == 1 + + def test_custom_workers_initialization(self): + """Service should accept custom workers count.""" + restore = WorkflowRunRestore(workers=5) + assert restore.workers == 5 + + def test_invalid_workers_raises_error(self): + """Service should raise ValueError for workers less than 1.""" + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=0) + + def test_negative_workers_raises_error(self): + """Service should raise ValueError for negative workers.""" + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=-1) + + +# --------------------------------------------------------------------------- +# Test _get_workflow_run_repo Method +# --------------------------------------------------------------------------- + + +class TestGetWorkflowRunRepo: + """Tests for WorkflowRunRestore._get_workflow_run_repo method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.DifyAPIRepositoryFactory") + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + @patch("services.retention.workflow_run.restore_archived_workflow_run.db") + def test_first_call_creates_repo(self, mock_db, mock_sessionmaker, mock_factory): + """First call should create and cache repository.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + mock_repo = Mock() + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + assert restore.workflow_run_repo is mock_repo + mock_sessionmaker.assert_called_once_with(bind=mock_db.engine, expire_on_commit=False) + mock_factory.create_api_workflow_run_repository.assert_called_once_with(mock_session) + + def test_cached_repo_returned(self): + """Subsequent calls should return cached repository.""" + restore = WorkflowRunRestore() + mock_repo = Mock() + restore.workflow_run_repo = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + + +# --------------------------------------------------------------------------- +# Test _load_manifest_from_zip Method +# --------------------------------------------------------------------------- + + +class TestLoadManifestFromZip: + """Tests for WorkflowRunRestore._load_manifest_from_zip method.""" + + def test_load_valid_manifest(self): + """Should load manifest from valid zip.""" + manifest_data = {"schema_version": "1.0", "tables": {}} + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest_data)) + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + result = WorkflowRunRestore._load_manifest_from_zip(archive) + + assert result == manifest_data + + def test_missing_manifest_raises_error(self): + """Should raise ValueError when manifest.json is missing.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("other.txt", "data") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(ValueError, match="manifest.json missing from archive bundle"): + WorkflowRunRestore._load_manifest_from_zip(archive) + + def test_invalid_json_raises_error(self): + """Should raise ValueError when manifest contains invalid JSON.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", "invalid json") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(json.JSONDecodeError): + WorkflowRunRestore._load_manifest_from_zip(archive) + + +# --------------------------------------------------------------------------- +# Test _get_schema_version Method +# --------------------------------------------------------------------------- + + +class TestGetSchemaVersion: + """Tests for WorkflowRunRestore._get_schema_version method.""" + + def test_valid_schema_version(self): + """Should return valid schema version from manifest.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "1.0"} + result = restore._get_schema_version(manifest) + assert result == "1.0" + + def test_missing_schema_version_defaults_to_1_0(self): + """Should default to 1.0 when schema_version is missing.""" + restore = WorkflowRunRestore() + manifest = {"tables": {}} + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._get_schema_version(manifest) + + assert result == "1.0" + mock_logger.warning.assert_called_once_with("Manifest missing schema_version; defaulting to 1.0") + + def test_unsupported_schema_version_raises_error(self): + """Should raise ValueError for unsupported schema version.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "2.0"} + + with pytest.raises(ValueError, match="Unsupported schema_version 2.0"): + restore._get_schema_version(manifest) + + def test_numeric_schema_version_converted_to_string(self): + """Should convert numeric schema version to string.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": 1} + + # This should raise ValueError because "1" is not in SCHEMA_MAPPERS (only "1.0" is) + with pytest.raises(ValueError, match="Unsupported schema_version 1"): + restore._get_schema_version(manifest) + + +# --------------------------------------------------------------------------- +# Test _apply_schema_mapping Method +# --------------------------------------------------------------------------- + + +class TestApplySchemaMapping: + """Tests for WorkflowRunRestore._apply_schema_mapping method.""" + + def test_no_mapping_returns_original(self): + """Should return original record when no mapping exists.""" + restore = WorkflowRunRestore() + record = {"id": "test", "name": "test"} + result = restore._apply_schema_mapping("workflow_runs", "1.0", record) + assert result == record + + def test_mapping_applied(self): + """Should apply mapping when it exists.""" + restore = WorkflowRunRestore() + + def test_mapper(record): + return {**record, "mapped": True} + + # Add test mapper to SCHEMA_MAPPERS + original_mappers = SCHEMA_MAPPERS.copy() + SCHEMA_MAPPERS["1.0"]["test_table"] = test_mapper + + try: + record = {"id": "test"} + result = restore._apply_schema_mapping("test_table", "1.0", record) + assert result == {"id": "test", "mapped": True} + finally: + # Restore original mappers + SCHEMA_MAPPERS.clear() + SCHEMA_MAPPERS.update(original_mappers) + + +# --------------------------------------------------------------------------- +# Test _convert_datetime_fields Method +# --------------------------------------------------------------------------- + + +class TestConvertDatetimeFields: + """Tests for WorkflowRunRestore._convert_datetime_fields method.""" + + def test_iso_datetime_conversion(self): + """Should convert ISO datetime strings to datetime objects.""" + restore = WorkflowRunRestore() + + record = {"created_at": "2024-01-01T12:00:00", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["name"] == "test" + + def test_invalid_datetime_ignored(self): + """Should ignore invalid datetime strings.""" + restore = WorkflowRunRestore() + + record = {"created_at": "invalid-date", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["created_at"] == "invalid-date" + assert result["name"] == "test" + + def test_non_datetime_columns_unchanged(self): + """Should leave non-datetime columns unchanged.""" + restore = WorkflowRunRestore() + + record = {"id": "test", "tenant_id": "tenant-123"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["id"] == "test" + assert result["tenant_id"] == "tenant-123" + + +# --------------------------------------------------------------------------- +# Test _get_model_column_info Method +# --------------------------------------------------------------------------- + + +class TestGetModelColumnInfo: + """Tests for WorkflowRunRestore._get_model_column_info method.""" + + def test_column_info_extraction(self): + """Should extract column information correctly.""" + restore = WorkflowRunRestore() + + column_names, required_columns, non_nullable_with_default = restore._get_model_column_info(WorkflowRun) + + # Check that we get some expected columns + assert "id" in column_names + assert "tenant_id" in column_names + assert "app_id" in column_names + assert "created_at" in column_names + assert "created_by" in column_names + assert "status" in column_names + + # WorkflowRun model has no required columns (all have defaults or are auto-generated) + assert required_columns == set() + + # Check columns with defaults or server defaults + assert "id" in non_nullable_with_default + assert "created_at" in non_nullable_with_default + assert "elapsed_time" in non_nullable_with_default + assert "total_tokens" in non_nullable_with_default + + +# --------------------------------------------------------------------------- +# Test _restore_table_records Method +# --------------------------------------------------------------------------- + + +class TestRestoreTableRecords: + """Tests for WorkflowRunRestore._restore_table_records method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.TABLE_MODELS") + def test_unknown_table_returns_zero(self, mock_table_models): + """Should return 0 for unknown table.""" + restore = WorkflowRunRestore() + mock_table_models.get.return_value = None + + mock_session = Mock() + records = [{"id": "test"}] + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._restore_table_records(mock_session, "unknown_table", records, schema_version="1.0") + + assert result == 0 + mock_logger.warning.assert_called_once_with("Unknown table: %s", "unknown_table") + + def test_empty_records_returns_zero(self): + """Should return 0 for empty records list.""" + restore = WorkflowRunRestore() + mock_session = Mock() + + result = restore._restore_table_records(mock_session, "workflow_runs", [], schema_version="1.0") + assert result == 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert): + """Should successfully restore records.""" + restore = WorkflowRunRestore() + + # Mock session and execution + mock_session = Mock() + mock_result = Mock() + mock_result.rowcount = 2 + mock_session.execute.return_value = mock_result + mock_cast.return_value = mock_result + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + records = [{"id": "test1", "tenant_id": "tenant-123"}, {"id": "test2", "tenant_id": "tenant-123"}] + + result = restore._restore_table_records(mock_session, "workflow_runs", records, schema_version="1.0") + + assert result == 2 + mock_session.execute.assert_called_once() + + def test_missing_required_columns_raises_error(self): + """Should raise ValueError for missing required columns.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + # Since WorkflowRun has no required columns, we need to test with a different model + # Let's test with a mock model that has required columns + mock_model = Mock() + + # Mock a required column + required_column = Mock() + required_column.key = "required_field" + required_column.nullable = False + required_column.default = None + required_column.server_default = None + required_column.autoincrement = False + required_column.type = Mock() + + # Mock the __table__ attribute properly + mock_table = Mock() + mock_table.columns = [required_column] + mock_model.__table__ = mock_table + + records = [{"name": "test"}] # Missing required 'required_field' + + with patch.dict(TABLE_MODELS, {"test_table": mock_model}): + with pytest.raises(ValueError, match="Missing required columns for test_table"): + restore._restore_table_records(mock_session, "test_table", records, schema_version="1.0") + + +# --------------------------------------------------------------------------- +# Test _restore_from_run Method +# --------------------------------------------------------------------------- + + +class TestRestoreFromRun: + """Tests for WorkflowRunRestore._restore_from_run method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_storage_not_configured(self, mock_get_storage): + """Should handle ArchiveStorageNotConfiguredError.""" + restore = WorkflowRunRestore() + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("Storage not configured") + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Storage not configured" in result.error + assert result.elapsed_time > 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_bundle_not_found(self, mock_get_storage): + """Should handle FileNotFoundError when archive bundle is missing.""" + restore = WorkflowRunRestore() + mock_storage = Mock() + mock_storage.get_object.side_effect = FileNotFoundError("Bundle not found") + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Archive bundle not found" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_dry_run_mode(self, mock_get_storage): + """Should handle dry run mode correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create a proper mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] == 2 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert, mock_get_storage): + """Should successfully restore from archive.""" + restore = WorkflowRunRestore() + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + def session_maker(): + return mock_session + + # Mock database execution to return integer counts + mock_result_workflow_runs = Mock() + mock_result_workflow_runs.rowcount = 1 + mock_result_app_logs = Mock() + mock_result_app_logs.rowcount = 2 + + # Configure session.execute to return different results based on the table + def mock_execute(stmt): + if "workflow_runs" in str(stmt): + return mock_result_workflow_runs + else: + return mock_result_app_logs + + mock_session.execute.side_effect = mock_execute + mock_cast.return_value = mock_result_workflow_runs + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Mock repository methods + with patch.object(restore, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = Mock() + mock_get_repo.return_value = mock_repo + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=session_maker) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] >= 1 # Just check it's restored + mock_session.commit.assert_called_once() + mock_repo.delete_archive_log_by_run_id.assert_called_once_with(mock_session, run.id) + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_invalid_archive_bundle(self, mock_get_storage): + """Should handle invalid archive bundle.""" + restore = WorkflowRunRestore() + + # Mock storage with invalid zip data + mock_storage = Mock() + mock_storage.get_object.return_value = b"invalid zip data" + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is False + # The error message comes from zipfile.BadZipFile which says "File is not a zip file" + assert "File is not a zip file" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_workflow_archive_log_input(self, mock_get_storage): + """Should handle WorkflowArchiveLog input correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(archive_log, session_maker=lambda: mock_session) + + assert result.success is True + assert result.run_id == archive_log.workflow_run_id + assert result.tenant_id == archive_log.tenant_id + + +# --------------------------------------------------------------------------- +# Test restore_batch Method +# --------------------------------------------------------------------------- + + +class TestRestoreBatch: + """Tests for WorkflowRunRestore.restore_batch method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_empty_tenant_ids_returns_empty(self, mock_sessionmaker): + """Should return empty list when tenant_ids is empty list.""" + restore = WorkflowRunRestore() + + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_batch( + tenant_ids=[], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert result == [] + + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_successful_batch_restore(self, mock_executor): + """Should successfully restore batch of workflow runs.""" + restore = WorkflowRunRestore(workers=2) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + # Mock repository and archive logs + mock_repo = Mock() + archive_log1 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-1") + archive_log2 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-2") + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log1, archive_log2] + + # Mock restore results + result1 = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + result2 = RestoreResult(run_id="run-2", tenant_id="tenant-1", success=True, restored_counts={}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result1, result2]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", side_effect=[result1, result2]): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_dry_run_batch_restore(self, mock_executor): + """Should handle dry run mode for batch restore.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log] + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 1 + assert results[0].success is True + + +# --------------------------------------------------------------------------- +# Test restore_by_run_id Method +# --------------------------------------------------------------------------- + + +class TestRestoreByRunId: + """Tests for WorkflowRunRestore.restore_by_run_id method.""" + + def test_archive_log_not_found(self): + """Should handle case when archive log is not found.""" + restore = WorkflowRunRestore() + + mock_repo = Mock() + mock_repo.get_archived_log_by_run_id.return_value = None + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore.restore_by_run_id("nonexistent-run") + + assert result.success is False + assert "not found" in result.error + assert result.run_id == "nonexistent-run" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_successful_restore_by_id(self, mock_sessionmaker): + """Should successfully restore by run ID.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_dry_run_restore_by_id(self, mock_sessionmaker): + """Should handle dry run mode for restore by ID.""" + restore = WorkflowRunRestore(dry_run=True) + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + +# --------------------------------------------------------------------------- +# Test RestoreResult Dataclass +# --------------------------------------------------------------------------- + + +class TestRestoreResult: + """Tests for RestoreResult dataclass.""" + + def test_restore_result_creation(self): + """Should create RestoreResult with all fields.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + success=True, + restored_counts={"workflow_runs": 1, "workflow_app_logs": 2}, + error=None, + elapsed_time=5.5, + ) + + assert result.run_id == "run-123" + assert result.tenant_id == "tenant-123" + assert result.success is True + assert result.restored_counts == {"workflow_runs": 1, "workflow_app_logs": 2} + assert result.error is None + assert result.elapsed_time == 5.5 + + def test_restore_result_with_error(self): + """Should create RestoreResult with error.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + success=False, + restored_counts={}, + error="Something went wrong", + ) + + assert result.success is False + assert result.error == "Something went wrong" + assert result.restored_counts == {} + assert result.elapsed_time == 0.0 # Default value + + +# --------------------------------------------------------------------------- +# Test Constants and Mappings +# --------------------------------------------------------------------------- + + +class TestConstantsAndMappings: + """Tests for module constants and mappings.""" + + def test_table_models_mapping(self): + """TABLE_MODELS should contain expected table mappings.""" + expected_tables = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, + } + + assert expected_tables == TABLE_MODELS + + def test_schema_mappers_structure(self): + """SCHEMA_MAPPERS should have correct structure.""" + assert isinstance(SCHEMA_MAPPERS, dict) + assert "1.0" in SCHEMA_MAPPERS + assert isinstance(SCHEMA_MAPPERS["1.0"], dict) + + +# --------------------------------------------------------------------------- +# Integration Tests +# --------------------------------------------------------------------------- + + +class TestIntegration: + """Integration tests combining multiple components.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_full_restore_flow(self, mock_executor, mock_get_storage): + """Test complete restore flow with all components.""" + restore = WorkflowRunRestore(workers=1) + + # Mock storage + mock_storage = Mock() + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + }, + } + tables_data = { + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "created_at": "2024-01-01T12:00:00", + } + ], + } + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock(manifest, tables_data) + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_result = Mock() + mock_result.rowcount = 1 + mock_session.execute.return_value = mock_result + + # Mock repository + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + # Mock ThreadPoolExecutor (not actually used in restore_by_run_id but needed for patch) + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") as mock_insert: + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_insert.return_value = mock_stmt + + with patch("services.retention.workflow_run.restore_archived_workflow_run.cast") as mock_cast: + mock_cast.return_value = mock_result + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_by_run_id("run-123") + + assert result.success is True + assert result.restored_counts.get("workflow_runs") == 1 diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..7f4b5fdaa3 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_based_extension_service.py @@ -0,0 +1,421 @@ +""" +Comprehensive unit tests for services/api_based_extension_service.py + +Covers: +- APIBasedExtensionService.get_all_by_tenant_id +- APIBasedExtensionService.save +- APIBasedExtensionService.delete +- APIBasedExtensionService.get_with_tenant_id +- APIBasedExtensionService._validation (new record & existing record branches) +- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.api_based_extension_service import APIBasedExtensionService + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_extension( + *, + id_: str | None = None, + tenant_id: str = "tenant-001", + name: str = "my-ext", + api_endpoint: str = "https://example.com/hook", + api_key: str = "secret-key-123", +) -> MagicMock: + """Return a lightweight mock that mimics APIBasedExtension.""" + ext = MagicMock() + ext.id = id_ + ext.tenant_id = tenant_id + ext.name = name + ext.api_endpoint = api_endpoint + ext.api_key = api_key + return ext + + +# --------------------------------------------------------------------------- +# Tests: get_all_by_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetAllByTenantId: + """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): + """Each api_key is decrypted and the list is returned.""" + ext1 = _make_extension(id_="id-1", api_key="enc-key-1") + ext2 = _make_extension(id_="id-2", api_key="enc-key-2") + + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ + ext1, + ext2, + ] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [ext1, ext2] + assert ext1.api_key == "decrypted-key" + assert ext2.api_key == "decrypted-key" + assert mock_decrypt.call_count == 2 + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): + """Returns an empty list gracefully when no records exist.""" + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [] + mock_decrypt.assert_not_called() + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): + """Verifies the DB is queried with the supplied tenant_id.""" + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") + + mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") + + +# --------------------------------------------------------------------------- +# Tests: save +# --------------------------------------------------------------------------- + + +class TestSave: + """Tests for APIBasedExtensionService.save.""" + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, "_validation") + def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): + """Happy path: validation passes, key is encrypted, record is added and committed.""" + ext = _make_extension(id_=None, api_key="plain-key-123") + + result = APIBasedExtensionService.save(ext) + + mock_validation.assert_called_once_with(ext) + mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") + assert ext.api_key == "encrypted-key" + mock_db.session.add.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + assert result is ext + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) + def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): + """If _validation raises, save should propagate the error without touching the DB.""" + ext = _make_extension(name="") + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(ext) + + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: delete +# --------------------------------------------------------------------------- + + +class TestDelete: + """Tests for APIBasedExtensionService.delete.""" + + @patch("services.api_based_extension_service.db") + def test_delete_removes_record_and_commits(self, mock_db): + """delete() must call session.delete with the extension and then commit.""" + ext = _make_extension(id_="delete-me") + + APIBasedExtensionService.delete(ext) + + mock_db.session.delete.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests: get_with_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetWithTenantId: + """Tests for APIBasedExtensionService.get_with_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): + """Found extension has its api_key decrypted before being returned.""" + ext = _make_extension(id_="ext-123", api_key="enc-key") + + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext + + result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") + + assert result is ext + assert ext.api_key == "decrypted-key" + mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") + + @patch("services.api_based_extension_service.db") + def test_raises_value_error_when_not_found(self, mock_db): + """Raises ValueError when no matching extension exists.""" + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None + + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): + """Verifies both tenant_id and extension id are used in the query.""" + ext = _make_extension(id_="ext-abc") + chain = mock_db.session.query.return_value + chain.filter_by.return_value.filter_by.return_value.first.return_value = ext + + APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") + + # First filter_by call uses tenant_id + chain.filter_by.assert_called_once_with(tenant_id="tenant-002") + # Second filter_by call uses id + chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") + + +# --------------------------------------------------------------------------- +# Tests: _validation (new record — id is falsy) +# --------------------------------------------------------------------------- + + +class TestValidationNewRecord: + """Tests for _validation() with a brand-new record (no id).""" + + def _build_mock_db(self, name_exists: bool = False): + mock_db = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() if name_exists else None + ) + return mock_db + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_valid_new_extension_passes(self, mock_db, mock_ping): + """A new record with all valid fields should pass without exceptions.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_empty(self, mock_db): + """Empty name raises ValueError.""" + ext = _make_extension(id_=None, name="") + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_none(self, mock_db): + """None name raises ValueError.""" + ext = _make_extension(id_=None, name=None) + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_already_exists_for_new_record(self, mock_db): + """A new record whose name already exists raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() + ) + ext = _make_extension(id_=None, name="duplicate-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_empty(self, mock_db): + """Empty api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint="") + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_none(self, mock_db): + """None api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint=None) + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_empty(self, mock_db): + """Empty api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="") + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_none(self, mock_db): + """None api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key=None) + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_too_short(self, mock_db): + """api_key shorter than 5 characters raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="abc") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_exactly_four_chars(self, mock_db): + """api_key with exactly 4 characters raises ValueError (boundary condition).""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="1234") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): + """api_key with exactly 5 characters should pass (boundary condition).""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="12345") + + # Should not raise + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _validation (existing record — id is truthy) +# --------------------------------------------------------------------------- + + +class TestValidationExistingRecord: + """Tests for _validation() with an existing record (id is set).""" + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_valid_existing_extension_passes(self, mock_db, mock_ping): + """An existing record whose name is unique (excluding self) should pass.""" + # .where(...).first() → None means no *other* record has that name + ( + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = None + ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): + """Existing record cannot use a name already owned by a different record.""" + ( + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = MagicMock() + ext = _make_extension(id_="existing-id", name="taken-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _ping_connection +# --------------------------------------------------------------------------- + + +class TestPingConnection: + """Tests for APIBasedExtensionService._ping_connection.""" + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_successful_ping_returns_pong(self, mock_requestor_class): + """When the endpoint returns {"result": "pong"}, no exception is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") + # Should not raise + APIBasedExtensionService._ping_connection(ext) + + mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): + """When the response is not {"result": "pong"}, a ValueError is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "error"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_network_exception_wraps_in_value_error(self, mock_requestor_class): + """Any exception raised during request is wrapped in a ValueError.""" + mock_client = MagicMock() + mock_client.request.side_effect = ConnectionError("network failure") + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): + """Exception raised by the requestor constructor itself is wrapped.""" + mock_requestor_class.side_effect = RuntimeError("bad config") + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_missing_result_key_raises_value_error(self, mock_requestor_class): + """A response dict without a 'result' key does not equal 'pong' → raises.""" + mock_client = MagicMock() + mock_client.request.return_value = {} # no 'result' key + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_uses_ping_extension_point(self, mock_requestor_class): + """The PING extension point is passed to the client.request call.""" + from models.api_based_extension import APIBasedExtensionPoint + + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + APIBasedExtensionService._ping_connection(ext) + + call_kwargs = mock_client.request.call_args + assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING + assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..33d26f4bcb --- /dev/null +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -0,0 +1,913 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import yaml + +from dify_graph.enums import NodeType +from models import Account, AppMode +from models.model import IconType +from services import app_dsl_service +from services.app_dsl_service import ( + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) + + +class _FakeHttpResponse: + def __init__(self, content: bytes, *, raises: Exception | None = None): + self.content = content + self._raises = raises + + def raise_for_status(self) -> None: + if self._raises is not None: + raise self._raises + + +def _account_mock(*, tenant_id: str = "tenant-1", account_id: str = "account-1") -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = app_dsl_service.CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def test_check_version_compatibility_invalid_version_returns_failed(): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED + + +def test_check_version_compatibility_newer_version_returns_pending(): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING + + +def test_check_version_compatibility_major_older_returns_pending(monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING + + +def test_check_version_compatibility_minor_older_returns_completed_with_warnings(): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS + + +def test_check_version_compatibility_equal_returns_completed(): + assert _check_version_compatibility(app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.COMPLETED + + +def test_import_app_invalid_import_mode_raises_value_error(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app(account=_account_mock(), import_mode="invalid-mode", yaml_content="version: '0.1.0'") + + +def test_import_app_yaml_url_requires_url(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url=None) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + +def test_import_app_yaml_content_requires_content(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=None) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error + + +def test_import_app_yaml_url_fetch_error_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + +def test_import_app_yaml_url_empty_content_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + +def test_import_app_yaml_url_file_too_large_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"x" * (app_dsl_service.DSL_MAX_SIZE + 1)) + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + +def test_import_app_yaml_not_mapping_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="[]") + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + +def test_import_app_version_not_str_returns_failed(): + service = AppDslService(MagicMock()) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + +def test_import_app_missing_app_data_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + +def test_import_app_app_id_not_found_returns_failed(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="missing-app", + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + +def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + existing_app = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + + session = MagicMock() + session.scalar.return_value = existing_app + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="app-1", + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + +def test_import_app_pending_stores_import_info_in_redis(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + app_dsl_service.redis_client.setex.assert_called_once() + call = app_dsl_service.redis_client.setex.call_args + redis_key = call.args[0] + assert redis_key.startswith(app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX) + + +def test_import_app_completed_uses_declared_dependencies(monkeypatch): + dependencies_payload = [{"id": "langgenius/google", "version": "1.0.0"}] + + plugin_deps = [SimpleNamespace(model_dump=lambda: dependencies_payload[0])] + monkeypatch.setattr( + app_dsl_service.PluginDependency, + "model_validate", + lambda d: plugin_deps[0], + ) + + created_app = SimpleNamespace(id="app-new", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": app_dsl_service.CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "app-new" + draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-new") + + +@pytest.mark.parametrize("has_workflow", [True, False]) +def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workflow: bool): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app = SimpleNamespace(id="app-legacy", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data) + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_workflow_variables.assert_called_once_with(app_id="app-legacy") + + +def test_import_app_yaml_error_returns_failed(monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="x: y") + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + +def test_import_app_unexpected_error_returns_failed(monkeypatch): + monkeypatch.setattr( + AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")) + ) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_workflow_yaml() + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + +def test_confirm_import_expired_returns_failed(): + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + +def test_confirm_import_invalid_pending_data_type_returns_failed(): + app_dsl_service.redis_client.get.return_value = 123 + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "Invalid import information" in result.error + + +def test_confirm_import_success_deletes_redis_key(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + app_dsl_service.redis_client.get.return_value = pending.model_dump_json() + + created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "confirmed-app" + app_dsl_service.redis_client.delete.assert_called_once() + + +def test_confirm_import_exception_returns_failed(monkeypatch): + app_dsl_service.redis_client.get.return_value = "not-json" + monkeypatch.setattr( + PendingData, "model_validate_json", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad")) + ) + + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert result.error == "bad" + + +def test_check_dependencies_returns_empty_when_no_redis_data(): + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert result.leaked_dependencies == [] + + +def test_check_dependencies_calls_analysis_service(monkeypatch): + pending = CheckDependenciesPendingData(dependencies=[], app_id="app-1").model_dump_json() + app_dsl_service.redis_client.get.return_value = pending + dep = app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert len(result.leaked_dependencies) == 1 + + +def test_create_or_update_app_missing_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + +def test_create_or_update_app_existing_app_updates_fields(monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + app_model_config=None, + ) + service = AppDslService(MagicMock()) + updated = service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.WORKFLOW.value, "name": "yaml-name", "icon_type": IconType.IMAGE, "icon": "X"}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + +def test_create_or_update_app_new_app_requires_tenant(): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + +def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(monkeypatch): + class DummyApp(SimpleNamespace): + pass + + monkeypatch.setattr(app_dsl_service, "App", DummyApp) + + sent: list[tuple[str, object]] = [] + monkeypatch.setattr(app_dsl_service.app_was_created, "send", lambda app, account: sent.append((app.id, account.id))) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = SimpleNamespace(unique_hash="uh") + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + monkeypatch.setattr( + AppDslService, "decrypt_dataset_id", lambda *_args, **_kwargs: "00000000-0000-0000-0000-000000000000" + ) + + session = MagicMock() + service = AppDslService(session) + deps = [ + app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + ] + data = { + "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "environment_variables": [{"x": 1}], + "conversation_variables": [{"y": 2}], + "graph": { + "nodes": [ + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}}, + ] + }, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=_account_mock(), dependencies=deps) + + assert app.tenant_id == "tenant-1" + assert sent == [(app.id, "account-1")] + app_dsl_service.redis_client.setex.assert_called() + workflow_service.sync_draft_workflow.assert_called_once() + + passed_graph = workflow_service.sync_draft_workflow.call_args.kwargs["graph"] + dataset_ids = passed_graph["nodes"][0]["data"]["dataset_ids"] + assert dataset_ids == ["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000000"] + + +def test_create_or_update_app_workflow_missing_workflow_data_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.WORKFLOW.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_requires_model_config(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_creates_model_config_and_sends_event(monkeypatch): + class DummyModelConfig(SimpleNamespace): + def from_model_config_dict(self, _cfg: dict): + return self + + monkeypatch.setattr(app_dsl_service, "AppModelConfig", DummyModelConfig) + + sent: list[str] = [] + monkeypatch.setattr( + app_dsl_service.app_model_config_was_updated, "send", lambda app, app_model_config: sent.append(app.id) + ) + + session = MagicMock() + service = AppDslService(session) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ) + service._create_or_update_app( + app=app, + data={"app": {"mode": AppMode.CHAT.value}, "model_config": {"model": {"provider": "openai"}}}, + account=_account_mock(), + ) + + assert app.app_model_config_id is not None + assert sent == ["app-1"] + session.add.assert_called() + + +def test_create_or_update_app_invalid_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.RAG_PIPELINE.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + +def test_export_dsl_delegates_by_mode(monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr(AppDslService, "_append_workflow_export_data", lambda **_kwargs: workflow_calls.append(True)) + monkeypatch.setattr( + AppDslService, "_append_model_config_export_data", lambda *_args, **_kwargs: model_calls.append(True) + ) + + workflow_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = SimpleNamespace( + mode=AppMode.CHAT.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + +def test_append_workflow_export_data_filters_and_overrides(monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}}, + {"data": {"type": NodeType.TOOL, "credential_id": "secret"}}, + { + "data": { + "type": NodeType.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + {"data": {"type": NodeType.TRIGGER_SCHEDULE.value, "config": {"x": 1}}}, + {"data": {"type": NodeType.TRIGGER_WEBHOOK.value, "webhook_url": "x", "webhook_debug_url": "y"}}, + {"data": {"type": NodeType.TRIGGER_PLUGIN.value, "subscription_id": "s"}}, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, "encrypt_dataset_id", lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}" + ) + monkeypatch.setattr( + TriggerScheduleNode := app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_workflow", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == ["enc:tenant-1:d1", "enc:tenant-1:d2"] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_workflow_export_data_missing_workflow_raises(monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + +def test_append_model_config_export_data_filters_credential_id(monkeypatch): + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_model_config", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}}) + app_model = SimpleNamespace(tenant_id="tenant-1", app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_model_config_export_data_requires_app_config(): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None)) + + +def test_extract_dependencies_from_workflow_graph_covers_all_node_types(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + + monkeypatch.setattr(app_dsl_service.ToolNodeData, "model_validate", lambda _d: SimpleNamespace(provider_id="p1")) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, "model_validate", lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")) + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr(app_dsl_service.KnowledgeRetrievalNodeData, "model_validate", kr_validate) + + graph = { + "nodes": [ + {"data": {"type": NodeType.TOOL}}, + {"data": {"type": NodeType.LLM}}, + {"data": {"type": NodeType.QUESTION_CLASSIFIER}}, + {"data": {"type": NodeType.PARAMETER_EXTRACTOR}}, + {"data": {"type": NodeType.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == ["tool:p1", "model:m1", "model:m2", "model:m3", "model:m4"] + + +def test_extract_dependencies_from_workflow_graph_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, "model_validate", lambda _d: (_ for _ in ()).throw(ValueError("bad")) + ) + deps = AppDslService._extract_dependencies_from_workflow_graph({"nodes": [{"data": {"type": NodeType.TOOL}}]}) + assert deps == [] + + +def test_extract_dependencies_from_model_config_parses_providers(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + +def test_extract_dependencies_from_model_config_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + +def test_get_leaked_dependencies_empty_returns_empty(): + assert AppDslService.get_leaked_dependencies("tenant-1", []) == [] + + +def test_get_leaked_dependencies_delegates(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies("tenant-1", [SimpleNamespace(id="x")]) + assert len(res) == 1 + + +def test_encrypt_decrypt_dataset_id_respects_config(monkeypatch): + tenant_id = "tenant-1" + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", False) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + +def test_decrypt_dataset_id_returns_plain_uuid_unchanged(): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id="tenant-1") == value + + +def test_decrypt_dataset_id_returns_none_on_invalid_data(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id="tenant-1") is None + + +def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id="tenant-1") + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id="tenant-1") is None + + +def test_is_valid_uuid_handles_bad_inputs(): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index 47b759bc7d..c2b430c551 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -1,14 +1,50 @@ +""" +Comprehensive unit tests for services.app_generate_service.AppGenerateService. + +Covers: + - _build_streaming_task_on_subscribe (streams / pubsub / exception / idempotency) + - generate (COMPLETION / AGENT_CHAT / CHAT / ADVANCED_CHAT / WORKFLOW / invalid mode, + streaming & blocking, billing, quota-refund-on-error, rate_limit.exit) + - _get_max_active_requests (all limit combos) + - generate_single_iteration (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_single_loop (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_more_like_this + - _get_workflow (debugger / non-debugger / specific id / invalid format / not found) + - get_response_generator (ended / non-ended workflow run) +""" + +import threading +import time +import uuid +from contextlib import contextmanager from unittest.mock import MagicMock -import services.app_generate_service as app_generate_service_module +import pytest + +import services.app_generate_service as ags_module +from core.app.entities.app_invoke_entities import InvokeFrom from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +# --------------------------------------------------------------------------- +# Helpers / Fakes +# --------------------------------------------------------------------------- class _DummyRateLimit: + """Minimal stand-in for RateLimit that never touches Redis.""" + + _instance_dict: dict[str, "_DummyRateLimit"] = {} + + def __new__(cls, client_id: str, max_active_requests: int): + # avoid singleton caching across tests + instance = object.__new__(cls) + return instance + def __init__(self, client_id: str, max_active_requests: int) -> None: self.client_id = client_id self.max_active_requests = max_active_requests + self._exited: list[str] = [] @staticmethod def gen_request_key() -> str: @@ -18,101 +54,720 @@ class _DummyRateLimit: return request_id or "dummy-request-id" def exit(self, request_id: str) -> None: - return None + self._exited.append(request_id) def generate(self, generator, request_id: str): return generator -def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) +def _make_app(mode: AppMode | str, *, max_active_requests: int = 0, is_agent: bool = False) -> MagicMock: + app = MagicMock() + app.mode = mode + app.id = "app-id" + app.tenant_id = "tenant-id" + app.max_active_requests = max_active_requests + app.is_agent = is_agent + return app - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.created_by = "owner-id" - - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) - - generator_spy = mocker.patch( - "services.app_generate_service.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) - - app_model = MagicMock() - app_model.mode = AppMode.WORKFLOW - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False +def _make_user() -> MagicMock: user = MagicMock() user.id = "user-id" - - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args={"inputs": {"k": "v"}}, - invoke_from=MagicMock(), - streaming=False, - ) - - assert result == {"result": "ok"} - - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert pause_state_config is not None - assert pause_state_config.state_owner_user_id == "owner-id" + return user -def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch): - """ - Regression test: ADVANCED_CHAT in blocking mode should return a plain dict - (non-streaming), and must not go through the async retrieve_events path. - Keeps behavior consistent with WORKFLOW blocking branch. - """ - # Disable billing and stub RateLimit to a no-op that just passes values through - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) - - # Arrange a fake workflow and wire AppGenerateService._get_workflow to return it +def _make_workflow(*, workflow_id: str = "workflow-id", created_by: str = "owner-id") -> MagicMock: workflow = MagicMock() - workflow.id = "workflow-id" - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + workflow.id = workflow_id + workflow.created_by = created_by + return workflow - # Spy on the streaming retrieval path to ensure it's NOT called - retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") - # Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False - generate_spy = mocker.patch( - "services.app_generate_service.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, - ) +@contextmanager +def _noop_rate_limit_context(rate_limit, request_id): + """Drop-in replacement for rate_limit_context that doesn't touch Redis.""" + yield - # Minimal app model for ADVANCED_CHAT - app_model = MagicMock() - app_model.mode = AppMode.ADVANCED_CHAT - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False - user = MagicMock() - user.id = "user-id" +# --------------------------------------------------------------------------- +# _build_streaming_task_on_subscribe +# --------------------------------------------------------------------------- +class TestBuildStreamingTaskOnSubscribe: + """Tests for AppGenerateService._build_streaming_task_on_subscribe.""" - # Must include query and inputs for AdvancedChatAppGenerator - args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}} + def test_streams_mode_starts_immediately(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + # task started immediately during build + assert called == [1] + # calling the returned callback is idempotent + cb() + assert called == [1] # not called again - # Act: call service with streaming=False (blocking mode) - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args=args, - invoke_from=MagicMock(), - streaming=False, - ) + def test_pubsub_mode_starts_on_subscribe(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) # large to prevent timer + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + # second call is idempotent + cb() + assert called == [1] - # Assert: returns the dict from generate(), and did not call retrieve_events() - assert result == {"result": "ok"} - assert generate_spy.call_args.kwargs.get("streaming") is False - retrieve_spy.assert_not_called() + def test_sharded_mode_starts_on_subscribe(self, monkeypatch): + """sharded is treated like pubsub (i.e. not 'streams').""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + + def test_pubsub_fallback_timer_fires(self, monkeypatch): + """When nobody subscribes fast enough the fallback timer fires.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 50) # 50 ms + called = [] + _cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + time.sleep(0.2) # give the timer time to fire + assert called == [1] + + def test_exception_in_start_task_returns_false(self, monkeypatch): + """When start_task raises, _try_start returns False and next call retries.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + call_count = 0 + + def _bad(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("boom") + + cb = AppGenerateService._build_streaming_task_on_subscribe(_bad) + # first call inside build raised, but is caught; second call via cb succeeds + assert call_count == 1 + cb() + assert call_count == 2 + + def test_concurrent_subscribe_only_starts_once(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + call_count = 0 + + def _inc(): + nonlocal call_count + call_count += 1 + + cb = AppGenerateService._build_streaming_task_on_subscribe(_inc) + threads = [threading.Thread(target=cb) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# _get_max_active_requests +# --------------------------------------------------------------------------- +class TestGetMaxActiveRequests: + def test_both_zero_returns_zero(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 0 + + def test_app_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_config_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 10) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 10 + + def test_both_non_zero_returns_min(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 20) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_default_active_requests_used_when_app_has_none(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 15) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 15 + + +# --------------------------------------------------------------------------- +# generate – every AppMode branch +# --------------------------------------------------------------------------- +class TestGenerate: + """Tests for AppGenerateService.generate covering each mode.""" + + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + # Prevent AppExecutionParams.new from touching real models via isinstance + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + # -- COMPLETION --------------------------------------------------------- + def test_completion_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"result": "ok"}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "ok"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via mode ------------------------------------------------ + def test_agent_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": "agent"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.AGENT_CHAT), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via is_agent flag (non-AGENT_CHAT mode) ----------------- + def test_agent_via_is_agent_flag(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": "agent-via-flag"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=True) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent-via-flag"} + gen_spy.assert_called_once() + + # -- CHAT --------------------------------------------------------------- + def test_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.ChatAppGenerator.generate", + return_value={"result": "chat"}, + ) + mocker.patch( + "services.app_generate_service.ChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=False) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "chat"} + gen_spy.assert_called_once() + + # -- ADVANCED_CHAT blocking --------------------------------------------- + def test_advanced_chat_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + + retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.generate", + return_value={"result": "advanced-blocking"}, + ) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "advanced-blocking"} + assert gen_spy.call_args.kwargs.get("streaming") is False + retrieve_spy.assert_not_called() + + # -- ADVANCED_CHAT streaming -------------------------------------------- + def test_advanced_chat_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-1", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe call the real on_subscribe + # so the inner closure (line 165) actually executes. + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + # In streaming mode it should go through retrieve_events, not generate + gen_instance.retrieve_events.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- WORKFLOW blocking -------------------------------------------------- + def test_workflow_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.generate", + return_value={"result": "workflow-blocking"}, + ) + mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "workflow-blocking"} + call_kwargs = gen_spy.call_args.kwargs + assert call_kwargs.get("pause_state_config") is not None + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # -- WORKFLOW streaming ------------------------------------------------- + def test_workflow_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-2", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe invoke the real on_subscribe + # so the inner closure (line 216) actually executes. + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + retrieve_spy = mocker.patch( + "services.app_generate_service.MessageBasedAppGenerator.retrieve_events", + return_value=iter([]), + ) + mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + retrieve_spy.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- Invalid mode ------------------------------------------------------- + def test_invalid_mode_raises(self, mocker): + app = _make_app("invalid-mode", is_agent=False) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + +# --------------------------------------------------------------------------- +# generate – billing / quota +# --------------------------------------------------------------------------- +class TestGenerateBilling: + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + consume_mock = mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + consume_mock.assert_called_once_with("tenant-id") + + def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): + from services.errors.app import QuotaExceededError + from services.errors.llm import InvokeRateLimitError + + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), + ) + + with pytest.raises(InvokeRateLimitError): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_exception_refunds_quota_and_exits_rate_limit(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + side_effect=RuntimeError("boom"), + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + with pytest.raises(RuntimeError, match="boom"): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + quota_charge.refund.assert_called_once() + + def test_rate_limit_exit_called_in_finally_for_blocking(self, mocker, monkeypatch): + """For non-streaming (blocking) calls, rate_limit.exit should be called in finally.""" + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + + exit_calls: list[str] = [] + + class _TrackingRateLimit(_DummyRateLimit): + def exit(self, request_id: str) -> None: + exit_calls.append(request_id) + + mocker.patch("services.app_generate_service.RateLimit", _TrackingRateLimit) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + # exit is called in finally block for non-streaming + assert len(exit_calls) >= 1 + + +# --------------------------------------------------------------------------- +# _get_workflow +# --------------------------------------------------------------------------- +class TestGetWorkflow: + def test_debugger_fetches_draft(self, mocker): + draft_wf = _make_workflow() + ws = MagicMock() + ws.get_draft_workflow.return_value = draft_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + assert result is draft_wf + ws.get_draft_workflow.assert_called_once() + + def test_debugger_raises_when_no_draft(self, mocker): + ws = MagicMock() + ws.get_draft_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not initialized"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + + def test_non_debugger_fetches_published(self, mocker): + pub_wf = _make_workflow() + ws = MagicMock() + ws.get_published_workflow.return_value = pub_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + assert result is pub_wf + ws.get_published_workflow.assert_called_once() + + def test_non_debugger_raises_when_no_published(self, mocker): + ws = MagicMock() + ws.get_published_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not published"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + + def test_specific_workflow_id_valid_uuid(self, mocker): + valid_uuid = str(uuid.uuid4()) + specific_wf = _make_workflow(workflow_id=valid_uuid) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = specific_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + assert result is specific_wf + ws.get_published_workflow_by_id.assert_called_once() + + def test_specific_workflow_id_invalid_uuid(self, mocker): + ws = MagicMock() + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowIdFormatError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id="not-a-uuid" + ) + + def test_specific_workflow_id_not_found(self, mocker): + valid_uuid = str(uuid.uuid4()) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowNotFoundError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + + +# --------------------------------------------------------------------------- +# generate_single_iteration +# --------------------------------------------------------------------------- +class TestGenerateSingleIteration: + def test_advanced_chat_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_iteration_generate", + return_value={"event": "iteration"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "iteration"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_iteration_generate", + return_value={"event": "wf-iteration"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "wf-iteration"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.CHAT) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_iteration(app_model=app, user=_make_user(), node_id="n1", args={}) + + +# --------------------------------------------------------------------------- +# generate_single_loop +# --------------------------------------------------------------------------- +class TestGenerateSingleLoop: + def test_advanced_chat_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_loop_generate", + return_value={"event": "loop"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "loop"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_loop_generate", + return_value={"event": "wf-loop"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "wf-loop"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.COMPLETION) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_loop(app_model=app, user=_make_user(), node_id="n1", args=MagicMock()) + + +# --------------------------------------------------------------------------- +# generate_more_like_this +# --------------------------------------------------------------------------- +class TestGenerateMoreLikeThis: + def test_delegates_to_completion_generator(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate_more_like_this", + return_value={"result": "similar"}, + ) + result = AppGenerateService.generate_more_like_this( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + message_id="msg-1", + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + assert result == {"result": "similar"} + gen_spy.assert_called_once() + assert gen_spy.call_args.kwargs["stream"] is True + + +# --------------------------------------------------------------------------- +# get_response_generator +# --------------------------------------------------------------------------- +class TestGetResponseGenerator: + def test_non_ended_workflow_run(self, mocker): + app = _make_app(AppMode.ADVANCED_CHAT) + workflow_run = MagicMock() + workflow_run.id = "run-1" + workflow_run.status.is_ended.return_value = False + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([{"event": "started"}]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + gen_instance.retrieve_events.assert_called_once() + + def test_ended_workflow_run_still_returns_generator(self, mocker): + """Even when the run is ended, the current code still returns a generator (TODO branch).""" + app = _make_app(AppMode.WORKFLOW) + workflow_run = MagicMock() + workflow_run.id = "run-2" + workflow_run.status.is_ended.return_value = True + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + # current impl falls through the TODO and still creates a generator + gen_instance.retrieve_events.assert_called_once() diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 5099362e00..3c0db51cd2 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -1,9 +1,12 @@ import datetime -from unittest.mock import Mock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session +from enums.cloud_plan import CloudPlan +from services import clear_free_plan_tenant_expired_logs as service_module from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs @@ -156,13 +159,453 @@ class TestClearFreePlanTenantExpiredLogs: # Should call delete for each table that has records assert mock_session.query.return_value.where.return_value.delete.called - def test_clear_message_related_tables_logging_output( - self, mock_session, sample_message_ids, sample_records, capsys + def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes( + self, mock_session, sample_message_ids ): - """Test that logging output is generated.""" + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = [record] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - pass + mock_storage.save.assert_not_called() + assert mock_session.query.return_value.where.return_value.delete.called + + +class _ImmediateFuture: + def __init__(self, fn, args, kwargs): + self._fn = fn + self._args = args + self._kwargs = kwargs + + def result(self): + return self._fn(*self._args, **self._kwargs) + + +class _ImmediateExecutor: + def __init__(self, *args, **kwargs) -> None: + self.submitted: list[tuple[object, tuple[object, ...], dict[str, object]]] = [] + + def submit(self, fn, *args, **kwargs): + self.submitted.append((fn, args, kwargs)) + return _ImmediateFuture(fn, args, kwargs) + + +def _session_wrapper_for_no_autoflush(session: Mock) -> Mock: + """ + ClearFreePlanTenantExpiredLogs.process_tenant uses: + with Session(db.engine).no_autoflush as session: + so Session(db.engine) must return an object with a no_autoflush context manager. + """ + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + + wrapper = MagicMock() + wrapper.no_autoflush = cm + return wrapper + + +def _session_wrapper_for_direct(session: Mock) -> Mock: + """ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:""" + wrapper = MagicMock() + wrapper.__enter__.return_value = session + wrapper.__exit__.return_value = None + return wrapper + + +def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace( + all=lambda: [SimpleNamespace(id="app-1"), SimpleNamespace(id="app-2")] + ) + ), + ), + ) + + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + + clear_related = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", clear_related) + + # Session sequence for messages, conversations, workflow_app_logs loops: + # - messages: one batch then empty + # - conversations: one batch then empty + # - workflow app logs: one batch then empty + msg1 = SimpleNamespace(id="m1", to_dict=lambda: {"id": "m1"}) + conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"}) + log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"}) + + def make_query_with_batches(batches: list[list[object]]): + q = MagicMock() + q.where.return_value = q + q.limit.return_value = q + q.all.side_effect = batches + q.delete.return_value = 1 + return q + + msg_session_1 = MagicMock() + msg_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() + ) + msg_session_1.commit.return_value = None + + msg_session_2 = MagicMock() + msg_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.Message else MagicMock() + ) + msg_session_2.commit.return_value = None + + conv_session_1 = MagicMock() + conv_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() + ) + conv_session_1.commit.return_value = None + + conv_session_2 = MagicMock() + conv_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() + ) + conv_session_2.commit.return_value = None + + wal_session_1 = MagicMock() + wal_session_1.query.side_effect = ( + lambda model: make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_1.commit.return_value = None + + wal_session_2 = MagicMock() + wal_session_2.query.side_effect = ( + lambda model: make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_2.commit.return_value = None + + session_wrappers = [ + _session_wrapper_for_no_autoflush(msg_session_1), + _session_wrapper_for_no_autoflush(msg_session_2), + _session_wrapper_for_no_autoflush(conv_session_1), + _session_wrapper_for_no_autoflush(conv_session_2), + _session_wrapper_for_no_autoflush(wal_session_1), + _session_wrapper_for_no_autoflush(wal_session_2), + ] + + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repositories for workflow node executions and workflow runs + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [[SimpleNamespace(id="ne-1")], []] + node_repo.delete_executions_by_ids.return_value = 1 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []] + run_repo.delete_runs_by_ids.return_value = 1 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=10) + + # messages backup, conversations backup, node executions backup, runs backup, workflow app logs backup + assert mock_storage.save.call_count >= 5 + clear_related.assert_called() + + +def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + # Total tenant count query + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 2 + count_session.query.return_value = count_query + + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", True) + + def fake_get_info(tenant_id: str): + if tenant_id == "t_sandbox": + return {"subscription": {"plan": CloudPlan.SANDBOX}} + if tenant_id == "t_fail": + raise RuntimeError("boom") + return {"subscription": {"plan": "team"}} + + monkeypatch.setattr(service_module.BillingService, "get_info", staticmethod(fake_get_info)) + + process_tenant_mock = MagicMock(side_effect=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("err"))) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + logger_exc = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exc) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=["t_sandbox", "t_paid", "t_fail"]) + + # Only sandbox tenant should attempt processing, and its failure should be swallowed + logged. + assert process_tenant_mock.call_count == 1 + assert logger_exc.call_count >= 1 + + +def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + fixed_now = started_at + datetime.timedelta(hours=2) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + # Sessions used: + # 1) total tenant count + # 2) per-batch tenant scan (count + tenant list) + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + q1 = MagicMock() + q1.where.return_value = q1 + q1.count.return_value = 200 + q2 = MagicMock() + q2.where.return_value = q2 + q2.count.return_value = 200 + q3 = MagicMock() + q3.where.return_value = q3 + q3.count.return_value = 200 + q4 = MagicMock() + q4.where.return_value = q4 + q4.count.return_value = 50 # choose this interval, then scale it + + rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] + + sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + # Should submit/process tenants from the batch query + assert process_tenant_mock.call_count == 2 + + +def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 100 + count_session.query.return_value = count_query + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", MagicMock()) + + tenant_ids = [f"t{i}" for i in range(100)] + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=tenant_ids) + + assert any("Processed 100 tenants" in str(call.args[0]) for call in echo_mock.call_args_list) + + +def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + # Keep the total range smaller than the minimum interval (1 hour) so the loop runs once. + fixed_now = started_at + datetime.timedelta(minutes=30) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + # Count results for all 5 intervals, all > 100 => take the for-else path. + count_queries = [] + for _ in range(5): + q = MagicMock() + q.where.return_value = q + q.count.return_value = 200 + count_queries.append(q) + + rows = [SimpleNamespace(id="tenant-a")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [*count_queries, q_rs] + + sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + assert process_tenant_mock.call_count == 1 + assert len(count_queries) == 5 + assert batch_session.query.call_count >= 6 + + +def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="app-1")])), + ), + ) + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", MagicMock()) + + # Make message/conversation/workflow_app_log loops no-op (empty immediately) + empty_session = MagicMock() + q_empty = MagicMock() + q_empty.where.return_value = q_empty + q_empty.limit.return_value = q_empty + q_empty.all.return_value = [] + empty_session.query.return_value = q_empty + empty_session.commit.return_value = None + session_wrappers = [ + _session_wrapper_for_no_autoflush(empty_session), + _session_wrapper_for_no_autoflush(empty_session), + _session_wrapper_for_no_autoflush(empty_session), + ] + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repos: first returns exactly batch items -> no "< batch" break, second returns [] -> hit the len==0 break. + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [ + [SimpleNamespace(id="ne-1"), SimpleNamespace(id="ne-2")], + [], + ] + node_repo.delete_executions_by_ids.return_value = 2 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [ + [ + SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"}), + SimpleNamespace(id="wr-2", to_dict=lambda: {"id": "wr-2"}), + ], + [], + ] + run_repo.delete_runs_by_ids.return_value = 2 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=2) + + assert node_repo.get_expired_executions_batch.call_count == 2 + assert run_repo.get_expired_runs_batch.call_count == 2 diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index d8ecdf45fd..75551531a2 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -1,18 +1,29 @@ """ Comprehensive unit tests for ConversationService. -This file keeps non-SQL guard/unit tests. -SQL-related tests were migrated to testcontainers integration tests. +This file provides complete test coverage for all ConversationService methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. """ -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch +import pytest +from sqlalchemy import asc, desc + from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, Conversation, EndUser +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models import Account, ConversationVariable +from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService -from services.message_service import MessageService +from services.errors.conversation import ( + ConversationNotExistsError, + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) +from services.errors.message import MessageNotExistsError class ConversationServiceTestDataFactory: @@ -116,6 +127,84 @@ class ConversationServiceTestDataFactory: setattr(conversation, key, value) return conversation + @staticmethod + def create_message_mock( + message_id: str = "msg-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock Message object. + + Args: + message_id: Unique identifier for the message + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock Message object with specified attributes + """ + message = create_autospec(Message, instance=True) + message.id = message_id + message.conversation_id = conversation_id + message.app_id = app_id + message.query = kwargs.get("query", "Test message content") + message.created_at = kwargs.get("created_at", datetime.utcnow()) + for key, value in kwargs.items(): + setattr(message, key, value) + return message + + @staticmethod + def create_conversation_variable_mock( + variable_id: str = "var-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock ConversationVariable object. + + Args: + variable_id: Unique identifier for the variable + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock ConversationVariable object with specified attributes + """ + variable = create_autospec(ConversationVariable, instance=True) + variable.id = variable_id + variable.conversation_id = conversation_id + variable.app_id = app_id + variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} + variable.created_at = kwargs.get("created_at", datetime.utcnow()) + variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + + # Mock to_variable method + mock_variable = Mock() + mock_variable.id = variable_id + mock_variable.name = kwargs.get("name", "test_var") + mock_variable.value_type = kwargs.get("value_type", "string") + mock_variable.value = kwargs.get("value", "test_value") + mock_variable.description = kwargs.get("description", "") + mock_variable.selector = kwargs.get("selector", {}) + mock_variable.model_dump.return_value = { + "id": variable_id, + "name": kwargs.get("name", "test_var"), + "value_type": kwargs.get("value_type", "string"), + "value": kwargs.get("value", "test_value"), + "description": kwargs.get("description", ""), + "selector": kwargs.get("selector", {}), + } + variable.to_variable.return_value = mock_variable + + for key, value in kwargs.items(): + setattr(variable, key, value) + return variable + class TestConversationServicePagination: """Test conversation pagination operations.""" @@ -175,99 +264,958 @@ class TestConversationServicePagination: assert result.limit == 20 -class TestConversationServiceMessageCreation: - """ - Test message creation and pagination. +class TestConversationServiceHelpers: + """Test helper methods in ConversationService.""" - Tests MessageService operations for creating and retrieving messages - within conversations. - """ - - def test_pagination_returns_empty_when_no_user(self): + def test_get_sort_params_with_descending_sort(self): """ - Test that pagination returns empty result when user is None. + Test _get_sort_params with descending sort prefix. - This ensures proper handling of unauthenticated requests. + When sort_by starts with '-', should return field name and desc function. + """ + # Act + field, direction = ConversationService._get_sort_params("-updated_at") + + # Assert + assert field == "updated_at" + assert direction == desc + + def test_get_sort_params_with_ascending_sort(self): + """ + Test _get_sort_params with ascending sort. + + When sort_by doesn't start with '-', should return field name and asc function. + """ + # Act + field, direction = ConversationService._get_sort_params("created_at") + + # Assert + assert field == "created_at" + assert direction == asc + + def test_build_filter_condition_with_descending_sort(self): + """ + Test _build_filter_condition with descending sort direction. + + Should create a less-than filter condition. """ # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.updated_at = datetime.utcnow() # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=None, - conversation_id="conv-123", - first_id=None, - limit=10, + condition = ConversationService._build_filter_condition( + sort_field="updated_at", + sort_direction=desc, + reference_conversation=mock_conversation, ) # Assert - assert result.data == [] - assert result.has_more is False + # The condition should be a comparison expression + assert condition is not None - def test_pagination_returns_empty_when_no_conversation_id(self): + def test_build_filter_condition_with_ascending_sort(self): """ - Test that pagination returns empty result when conversation_id is None. + Test _build_filter_condition with ascending sort direction. - This ensures proper handling of invalid requests. + Should create a greater-than filter condition. + """ + # Arrange + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.created_at = datetime.utcnow() + + # Act + condition = ConversationService._build_filter_condition( + sort_field="created_at", + sort_direction=asc, + reference_conversation=mock_conversation, + ) + + # Assert + # The condition should be a comparison expression + assert condition is not None + + +class TestConversationServiceGetConversation: + """Test conversation retrieval operations.""" + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_account(self, mock_db_session): + """ + Test successful conversation retrieval with account user. + + Should return conversation when found with proper filters. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_account_id=user.id, from_source="console" + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + mock_db_session.query.assert_called_once_with(Conversation) + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_end_user(self, mock_db_session): + """ + Test successful conversation retrieval with end user. + + Should return conversation when found with proper filters for API user. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_end_user_id=user.id, from_source="api" + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + + @patch("services.conversation_service.db.session") + def test_get_conversation_not_found_raises_error(self, mock_db_session): + """ + Test that get_conversation raises error when conversation not found. + + Should raise ConversationNotExistsError when no matching conversation found. """ # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=user, - conversation_id="", - first_id=None, - limit=10, - ) + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = None - # Assert - assert result.data == [] - assert result.has_more is False + # Act & Assert + with pytest.raises(ConversationNotExistsError): + ConversationService.get_conversation(app_model, "conv-123", user) -class TestConversationServiceSummarization: - """ - Test conversation summarization (auto-generated names). +class TestConversationServiceRename: + """Test conversation rename operations.""" - Tests the auto_generate_name functionality that creates conversation - titles based on the first message. - """ - - @patch("services.conversation_service.db.session", autospec=True) - @patch("services.conversation_service.ConversationService.get_conversation", autospec=True) - @patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True) - def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_rename_with_manual_name(self, mock_get_conversation, mock_db_session): """ - Test renaming conversation with auto-generation enabled. + Test renaming conversation with manual name. - When auto_generate is True, the service should call the auto_generate_name - method to generate a new name for the conversation. + Should update conversation name and timestamp when auto_generate is False. """ # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock() - conversation.name = "Auto-generated Name" - # Mock the conversation lookup to return our test conversation mock_get_conversation.return_value = conversation - # Mock the auto_generate_name method to return the conversation + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id="conv-123", + user=user, + name="New Name", + auto_generate=False, + ) + + # Assert + assert result == conversation + assert conversation.name == "New Name" + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.ConversationService.auto_generate_name") + def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + """ + Test renaming conversation with auto-generation. + + Should call auto_generate_name when auto_generate is True. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation mock_auto_generate.return_value = conversation # Act result = ConversationService.rename( app_model=app_model, - conversation_id=conversation.id, + conversation_id="conv-123", user=user, - name="", + name=None, auto_generate=True, ) # Assert - mock_auto_generate.assert_called_once_with(app_model, conversation) assert result == conversation + mock_auto_generate.assert_called_once_with(app_model, conversation) + + +class TestConversationServiceAutoGenerateName: + """Test conversation auto-name generation operations.""" + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_success(self, mock_llm_generator, mock_db_session): + """ + Test successful auto-generation of conversation name. + + Should generate name using LLMGenerator and update conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator + mock_llm_generator.generate_conversation_name.return_value = "Generated Name" + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + assert conversation.name == "Generated Name" + mock_llm_generator.generate_conversation_name.assert_called_once_with( + app_model.tenant_id, message.query, conversation.id, app_model.id + ) + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + def test_auto_generate_name_no_message_raises_error(self, mock_db_session): + """ + Test auto-generation fails when no message found. + + Should raise MessageNotExistsError when conversation has no messages. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + # Mock database query to return None + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + ConversationService.auto_generate_name(app_model, conversation) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_handles_llm_exception(self, mock_llm_generator, mock_db_session): + """ + Test auto-generation handles LLM generator exceptions gracefully. + + Should continue without name when LLMGenerator fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator to raise exception + mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + # Name should remain unchanged due to exception + mock_db_session.commit.assert_called_once() + + +class TestConversationServiceDelete: + """Test conversation deletion operations.""" + + @patch("services.conversation_service.delete_conversation_related_data") + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_success(self, mock_get_conversation, mock_db_session, mock_delete_task): + """ + Test successful conversation deletion. + + Should delete conversation and schedule cleanup task. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock(name="Test App") + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Act + ConversationService.delete(app_model, "conv-123", user) + + # Assert + mock_db_session.delete.assert_called_once_with(conversation) + mock_db_session.commit.assert_called_once() + mock_delete_task.delay.assert_called_once_with(conversation.id) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session): + """ + Test deletion handles exceptions and rolls back transaction. + + Should rollback database changes when deletion fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_db_session.delete.side_effect = Exception("Database Error") + + # Act & Assert + with pytest.raises(Exception, match="Database Error"): + ConversationService.delete(app_model, "conv-123", user) + + # Assert rollback was called + mock_db_session.rollback.assert_called_once() + + +class TestConversationServiceConversationalVariable: + """Test conversational variable operations.""" + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_success(self, mock_get_conversation, mock_session_factory): + """ + Test successful retrieval of conversational variables. + + Should return paginated list of variables for conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + variable1 = ConversationServiceTestDataFactory.create_conversation_variable_mock() + variable2 = ConversationServiceTestDataFactory.create_conversation_variable_mock(variable_id="var-456") + + mock_session.scalars.return_value.all.return_value = [variable1, variable2] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 2 + assert result.limit == 10 + assert result.has_more is False + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_with_last_id(self, mock_get_conversation, mock_session_factory): + """ + Test retrieval of variables with last_id pagination. + + Should filter variables created after last_id. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( + created_at=datetime.utcnow() - timedelta(hours=1) + ) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + + mock_session.scalar.return_value = last_variable + mock_session.scalars.return_value.all.return_value = [variable] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="var-123", + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + assert result.limit == 10 + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_last_id_not_found_raises_error( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that invalid last_id raises ConversationVariableNotExistsError. + + Should raise error when last_id doesn't exist. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="invalid-id", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_mysql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for MySQL databases. + + Should apply JSON extraction filter for variable names. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "mysql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_postgresql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for PostgreSQL databases. + + Should apply JSON extraction filter for variable names. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "postgresql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + +class TestConversationServiceUpdateVariable: + """Test conversation variable update operations.""" + + @patch("services.conversation_service.variable_factory") + @patch("services.conversation_service.ConversationVariableUpdater") + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_success( + self, mock_get_conversation, mock_session_factory, mock_updater_class, mock_variable_factory + ): + """ + Test successful update of conversation variable. + + Should update variable value and return updated data. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="string") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": "new_value"} + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="new_value", + ) + + # Assert + assert result["id"] == "var-123" + assert result["value"] == "new_value" + mock_updater.update.assert_called_once_with("conv-123", updated_variable) + mock_updater.flush.assert_called_once() + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_not_found_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when variable doesn't exist. + + Should raise ConversationVariableNotExistsError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="invalid-id", + user=user, + new_value="new_value", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_type_mismatch_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when value type doesn't match expected type. + + Should raise ConversationVariableTypeMismatchError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="number") + mock_session.scalar.return_value = existing_variable + + # Act & Assert - Try to set string value for number variable + with pytest.raises(ConversationVariableTypeMismatchError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="string_value", # Wrong type + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_integer_number_compatibility( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that integer type accepts number values. + + Should allow number values for integer type variables. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="integer") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": 42} + + with ( + patch("services.conversation_service.variable_factory") as mock_variable_factory, + patch("services.conversation_service.ConversationVariableUpdater") as mock_updater_class, + ): + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value=42, # Number value for integer type + ) + + # Assert + assert result["value"] == 42 + mock_updater.update.assert_called_once() + + +class TestConversationServicePaginationAdvanced: + """Advanced pagination tests for ConversationService.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_last_id_not_found(self, mock_session_factory): + """ + Test pagination with invalid last_id raises error. + + Should raise LastConversationNotExistsError when last_id doesn't exist. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act & Assert + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id="invalid-id", + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_exclude_ids(self, mock_session_factory): + """ + Test pagination with exclude_ids filter. + + Should exclude specified conversation IDs from results. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=["excluded-123"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_has_more_detection(self, mock_session_factory): + """ + Test pagination has_more detection logic. + + Should set has_more=True when there are more results beyond limit. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # Return exactly limit items to trigger has_more check + conversations = [ + ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=f"conv-{i}") for i in range(20) + ] + mock_session.scalars.return_value.all.return_value = conversations + mock_session.scalar.return_value = conversations[-1] + + # Mock count query to return > 0 + mock_session.scalar.return_value = 5 # Additional items exist + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is True + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_different_sort_by(self, mock_session_factory): + """ + Test pagination with different sort fields. + + Should handle various sort_by parameters correctly. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Test different sort fields + sort_fields = ["created_at", "-updated_at", "name", "-status"] + + for sort_by in sort_fields: + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + sort_by=sort_by, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + +class TestConversationServiceEdgeCases: + """Test edge cases and error scenarios.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_with_end_user_api_source(self, mock_session_factory): + """ + Test pagination correctly handles EndUser with API source. + + Should use 'api' as from_source for EndUser instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source="api", from_end_user_id="user-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + @patch("services.conversation_service.session_factory") + def test_pagination_with_account_console_source(self, mock_session_factory): + """ + Test pagination correctly handles Account with console source. + + Should use 'console' as from_source for Account instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source="console", from_account_id="account-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + def test_pagination_with_include_ids_filter(self): + """ + Test pagination with include_ids filter. + + Should only return conversations with IDs in include_ids list. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv-123", "conv-456"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + # Verify that include_ids filter was applied + assert mock_session.scalars.called + + def test_pagination_with_empty_exclude_ids(self): + """ + Test pagination with empty exclude_ids list. + + Should handle empty exclude_ids gracefully. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=[], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is False diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py index 7f087a17d8..a3b1f46436 100644 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ b/api/tests/unit_tests/services/test_end_user_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, EndUser +from models.model import App, DefaultEndUserSessionID, EndUser from services.end_user_service import EndUserService @@ -44,6 +44,145 @@ class TestEndUserServiceFactory: return end_user +class TestEndUserServiceGetEndUserById: + """Unit tests for EndUserService.get_end_user_by_id method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): + """Test successful retrieval of end user by ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = mock_end_user + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result == mock_end_user + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.first.assert_called_once() + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): + """Test retrieval of non-existent end user returns None.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result is None + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): + """Test that query parameters are correctly applied.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + # Verify the where clause was called with the correct conditions + call_args = mock_query.where.call_args[0] + assert len(call_args) == 3 + # Check that the conditions match the expected filters + # (We can't easily test the exact conditions without importing SQLAlchemy) + + +class TestEndUserServiceGetOrCreateEndUser: + """Unit tests for EndUserService.get_or_create_end_user method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): + """Test get_or_create_end_user with specific user_id.""" + # Arrange + app_mock = factory.create_app_mock() + user_id = "user-123" + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, user_id) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id + ) + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): + """Test get_or_create_end_user without user_id (None).""" + # Arrange + app_mock = factory.create_app_mock() + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, None) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None + ) + + class TestEndUserServiceGetOrCreateEndUserByType: """ Unit tests for EndUserService.get_or_create_end_user_by_type method. @@ -60,6 +199,191 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): + """Test creating a new end user with specific user_id.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + # Verify new EndUser was created with correct parameters + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + added_user = mock_session.add.call_args[0][0] + assert added_user.tenant_id == tenant_id + assert added_user.app_id == app_id + assert added_user.type == type_enum + assert added_user.session_id == user_id + assert added_user.external_user_id == user_id + assert added_user._is_anonymous is False + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): + """Test creating a new end user with default session ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = None + type_enum = InvokeFrom.WEB_APP + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = mock_session.add.call_args[0][0] + assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): + """Test retrieving existing user with same type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + mock_logger.info.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): + """Test upgrading existing user with different type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + old_type = InvokeFrom.WEB_APP + new_type = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + assert existing_user.type == new_type + mock_session.commit.assert_called_once() + mock_logger.info.assert_called_once() + logger_call_args = mock_logger.info.call_args[0] + assert "Upgrading legacy EndUser" in logger_call_args[0] + # The old and new types are passed as separate arguments + assert mock_logger.info.call_args[0][1] == existing_user.id + assert mock_logger.info.call_args[0][2] == old_type + assert mock_logger.info.call_args[0][3] == new_type + assert mock_logger.info.call_args[0][4] == user_id + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): + """Test that query ordering prioritizes exact type matches.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + target_type = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_or_create_end_user_by_type( + type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + mock_query.order_by.assert_called_once() + # Verify that case statement is used for ordering + order_by_call = mock_query.order_by.call_args[0][0] + # The exact structure depends on SQLAlchemy's case implementation + # but we can verify it was called + # Test 10: Session context manager properly closes @patch("services.end_user_service.Session") @patch("services.end_user_service.db") @@ -93,3 +417,425 @@ class TestEndUserServiceGetOrCreateEndUserByType: # Verify context manager was entered and exited mock_context.__enter__.assert_called_once() mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): + """Test that all InvokeFrom enum values are supported.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = mock_session.add.call_args[0][0] + assert added_user.type == invoke_type + + +class TestEndUserServiceCreateEndUserBatch: + """Unit tests for EndUserService.create_end_user_batch method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): + """Test batch creation with empty app_ids list.""" + # Arrange + tenant_id = "tenant-123" + app_ids: list[str] = [] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert result == {} + mock_session_class.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_default_session_id(self, mock_db, mock_session_class): + """Test batch creation with empty user_id (uses default session).""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + for app_id, end_user in result.items(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): + """Test that duplicate app_ids are deduplicated while preserving order.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + # Should have 3 unique app_ids in original order + assert len(result) == 3 + assert "app-456" in result + assert "app-789" in result + assert "app-123" in result + + # Verify the order is preserved + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 3 + assert added_users[0].app_id == "app-456" + assert added_users[1].app_id == "app-789" + assert added_users[2].app_id == "app-123" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation when all users already exist.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + assert result["app-456"] == existing_user1 + assert result["app-789"] == existing_user2 + mock_session.add_all.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation with some existing and some new users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + # app-789 and app-123 don't exist + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 3 + assert result["app-456"] == existing_user1 + assert "app-789" in result + assert "app-123" in result + + # Should create 2 new users + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 2 + + mock_session.commit.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): + """Test batch creation handles duplicates in existing users gracefully.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Simulate duplicate records in database + existing_user1 = factory.create_end_user_mock( + user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + # Should prefer the first one found + assert result["app-456"] == existing_user1 + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): + """Test batch creation with all InvokeFrom types.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + added_user = mock_session.add_all.call_args[0][0][0] + assert added_user.type == invoke_type + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): + """Test batch creation with single app_id.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + assert "app-456" in result + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 1 + assert added_users[0].app_id == "app-456" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): + """Test batch creation correctly sets anonymous flag.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + + # Test with regular user ID + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act - authenticated user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is False + + # Test with default session ID + mock_session.reset_mock() + mock_query.reset_mock() + mock_query.all.return_value = [] + + # Act - anonymous user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=app_ids, + user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): + """Test that batch creation uses efficient single query for existing users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + # Should make exactly one query to check for existing users + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.all.assert_called_once() + + # Verify the where clause uses .in_() for app_ids + where_call = mock_query.where.call_args[0] + # The exact structure depends on SQLAlchemy implementation + # but we can verify it was called with the right parameters + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_session_context_manager(self, mock_db, mock_session_class): + """Test that batch creation properly uses session context manager.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py new file mode 100644 index 0000000000..b7259c3e82 --- /dev/null +++ b/api/tests/unit_tests/services/test_file_service.py @@ -0,0 +1,420 @@ +import base64 +import hashlib +import os +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker +from werkzeug.exceptions import NotFound + +from configs import dify_config +from models.enums import CreatorUserRole +from models.model import Account, EndUser, UploadFile +from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError +from services.file_service import FileService + + +class TestFileService: + @pytest.fixture + def mock_db_session(self): + session = MagicMock(spec=Session) + # Mock context manager behavior + session.__enter__.return_value = session + return session + + @pytest.fixture + def mock_session_maker(self, mock_db_session): + maker = MagicMock(spec=sessionmaker) + maker.return_value = mock_db_session + return maker + + @pytest.fixture + def file_service(self, mock_session_maker): + return FileService(session_factory=mock_session_maker) + + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + service = FileService(session_factory=engine) + assert isinstance(service._session_maker, sessionmaker) + + def test_init_with_sessionmaker(self): + maker = MagicMock(spec=sessionmaker) + service = FileService(session_factory=maker) + assert service._session_maker == maker + + def test_init_invalid_factory(self): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + FileService(session_factory="invalid") + + @patch("services.file_service.storage") + @patch("services.file_service.naive_utc_now") + @patch("services.file_service.extract_tenant_id") + @patch("services.file_service.file_helpers.get_signed_file_url") + def test_upload_file_success( + self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session + ): + # Setup + mock_tenant_id.return_value = "tenant_id" + mock_now.return_value = "2024-01-01" + mock_get_url.return_value = "http://signed-url" + + user = MagicMock(spec=Account) + user.id = "user_id" + content = b"file content" + filename = "test.jpg" + mimetype = "image/jpeg" + + # Execute + result = file_service.upload_file(filename=filename, content=content, mimetype=mimetype, user=user) + + # Assert + assert isinstance(result, UploadFile) + assert result.name == filename + assert result.tenant_id == "tenant_id" + assert result.size == len(content) + assert result.extension == "jpg" + assert result.mime_type == mimetype + assert result.created_by_role == CreatorUserRole.ACCOUNT + assert result.created_by == "user_id" + assert result.hash == hashlib.sha3_256(content).hexdigest() + assert result.source_url == "http://signed-url" + + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once_with(result) + mock_db_session.commit.assert_called_once() + + def test_upload_file_invalid_characters(self, file_service): + with pytest.raises(ValueError, match="Filename contains invalid characters"): + file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock()) + + def test_upload_file_long_filename(self, file_service, mock_db_session): + # Setup + long_name = "a" * 210 + ".txt" + user = MagicMock(spec=Account) + user.id = "user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename=long_name, content=b"test", mimetype="text/plain", user=user) + assert len(result.name) <= 205 # 200 + . + extension + assert result.name.endswith(".txt") + + def test_upload_file_blocked_extension(self, file_service): + with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe"): + with pytest.raises(BlockedFileExtensionError): + file_service.upload_file( + filename="test.exe", content=b"", mimetype="application/octet-stream", user=MagicMock() + ) + + def test_upload_file_unsupported_type_for_datasets(self, file_service): + with pytest.raises(UnsupportedFileTypeError): + file_service.upload_file( + filename="test.jpg", content=b"", mimetype="image/jpeg", user=MagicMock(), source="datasets" + ) + + def test_upload_file_too_large(self, file_service): + # 16MB file for an image with 15MB limit + content = b"a" * (16 * 1024 * 1024) + with patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 15): + with pytest.raises(FileTooLargeError): + file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock()) + + def test_upload_file_end_user(self, file_service, mock_db_session): + user = MagicMock(spec=EndUser) + user.id = "end_user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename="test.txt", content=b"test", mimetype="text/plain", user=user) + assert result.created_by_role == CreatorUserRole.END_USER + + def test_is_file_size_within_limit(self): + with ( + patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 10), + patch.object(dify_config, "UPLOAD_VIDEO_FILE_SIZE_LIMIT", 20), + patch.object(dify_config, "UPLOAD_AUDIO_FILE_SIZE_LIMIT", 30), + patch.object(dify_config, "UPLOAD_FILE_SIZE_LIMIT", 5), + ): + # Image + assert FileService.is_file_size_within_limit(extension="jpg", file_size=10 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="png", file_size=11 * 1024 * 1024) is False + + # Video + assert FileService.is_file_size_within_limit(extension="mp4", file_size=20 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="avi", file_size=21 * 1024 * 1024) is False + + # Audio + assert FileService.is_file_size_within_limit(extension="mp3", file_size=30 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="wav", file_size=31 * 1024 * 1024) is False + + # Default + assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False + + def test_get_file_base64_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "test_key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load_once.return_value = b"test content" + + # Execute + result = file_service.get_file_base64("file_id") + + # Assert + assert result == base64.b64encode(b"test content").decode() + mock_storage.load_once.assert_called_once_with("test_key") + + def test_get_file_base64_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_base64("non_existent") + + def test_upload_text_success(self, file_service, mock_db_session): + # Setup + text = "sample text" + text_name = "test.txt" + user_id = "user_id" + tenant_id = "tenant_id" + + with patch("services.file_service.storage") as mock_storage: + # Execute + result = file_service.upload_text(text, text_name, user_id, tenant_id) + + # Assert + assert result.name == text_name + assert result.size == len(text) + assert result.tenant_id == tenant_id + assert result.created_by == user_id + assert result.used is True + assert result.extension == "txt" + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_upload_text_long_name(self, file_service, mock_db_session): + long_name = "a" * 210 + with patch("services.file_service.storage"): + result = file_service.upload_text("text", long_name, "user", "tenant") + assert len(result.name) == 200 + + def test_get_file_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "pdf" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: + mock_extract.return_value = "Extracted text content" + + # Execute + result = file_service.get_file_preview("file_id") + + # Assert + assert result == "Extracted text content" + + def test_get_file_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_preview("non_existent") + + def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "exe" + mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_file_preview("file_id") + + def test_get_image_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "jpg" + upload_file.mime_type = "image/jpeg" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk1"]) + + # Execute + gen, mime = file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + # Assert + assert list(gen) == [b"chunk1"] + assert mime == "image/jpeg" + + def test_get_image_preview_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(UnsupportedFileTypeError): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk"]) + + gen, file = file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + assert list(gen) == [b"chunk"] + assert file == upload_file + + def test_get_file_generator_by_file_id_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_public_image_preview_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "png" + upload_file.mime_type = "image/png" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load.return_value = b"image content" + gen, mime = file_service.get_public_image_preview("file_id") + assert gen == b"image content" + assert mime == "image/png" + + def test_get_public_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_public_image_preview("file_id") + + def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_public_image_preview("file_id") + + def test_get_file_content_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load.return_value = b"hello world" + result = file_service.get_file_content("file_id") + assert result == "hello world" + + def test_get_file_content_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_content("file_id") + + def test_delete_file_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + # For session.scalar(select(...)) + mock_db_session.scalar.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + file_service.delete_file("file_id") + mock_storage.delete.assert_called_once_with("key") + mock_db_session.delete.assert_called_once_with(upload_file) + + def test_delete_file_not_found(self, file_service, mock_db_session): + mock_db_session.scalar.return_value = None + file_service.delete_file("file_id") + # Should return without doing anything + + @patch("services.file_service.db") + def test_get_upload_files_by_ids_empty(self, mock_db): + result = FileService.get_upload_files_by_ids("tenant_id", []) + assert result == {} + + @patch("services.file_service.db") + def test_get_upload_files_by_ids(self, mock_db): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "550e8400-e29b-41d4-a716-446655440000" + upload_file.tenant_id = "tenant_id" + mock_db.session.scalars().all.return_value = [upload_file] + + result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) + assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file + + def test_sanitize_zip_entry_name(self): + assert FileService._sanitize_zip_entry_name("path/to/file.txt") == "file.txt" + assert FileService._sanitize_zip_entry_name("../../../etc/passwd") == "passwd" + assert FileService._sanitize_zip_entry_name(" ") == "file" + assert FileService._sanitize_zip_entry_name("a\\b") == "a_b" + + def test_dedupe_zip_entry_name(self): + used = {"a.txt"} + assert FileService._dedupe_zip_entry_name("b.txt", used) == "b.txt" + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (1).txt" + used.add("a (1).txt") + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (2).txt" + + def test_build_upload_files_zip_tempfile(self): + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.txt" + upload_file.key = "key" + + with ( + patch("services.file_service.storage") as mock_storage, + patch("services.file_service.os.remove") as mock_remove, + ): + mock_storage.load.return_value = [b"chunk1", b"chunk2"] + + with FileService.build_upload_files_zip_tempfile(upload_files=[upload_file]) as tmp_path: + assert os.path.exists(tmp_path) + + mock_remove.assert_called_once() diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index e64d3c5406..74139fd12d 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -1,97 +1,291 @@ from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest +from sqlalchemy.engine import Engine +from configs import dify_config from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, + MemberRecipient, ) from dify_graph.runtime import VariablePool from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, + DeliveryTestEmailRecipient, DeliveryTestError, + DeliveryTestRegistry, + DeliveryTestResult, + DeliveryTestStatus, + DeliveryTestUnsupportedError, EmailDeliveryTestHandler, + HumanInputDeliveryTestService, + _build_form_link, ) -def _make_email_method() -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", +@pytest.fixture +def mock_db(monkeypatch): + mock_db = MagicMock() + monkeypatch.setattr(service_module, "db", mock_db) + return mock_db + + +def _make_valid_email_config(): + return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + + +def test_build_form_link(): + with patch.object(dify_config, "APP_WEB_URL", "http://example.com/"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + assert _build_form_link(None) is None + + with patch.object(dify_config, "APP_WEB_URL", None): + assert _build_form_link("token123") is None + + +class TestDeliveryTestRegistry: + def test_register(self): + registry = DeliveryTestRegistry() + assert len(registry._handlers) == 0 + handler = MagicMock() + registry.register(handler) + assert len(registry._handlers) == 1 + assert registry._handlers[0] == handler + + def test_register_and_dispatch(self): + handler = MagicMock() + handler.supports.return_value = True + handler.send_test.return_value = DeliveryTestResult(status=DeliveryTestStatus.OK) + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + result = registry.dispatch(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + handler.supports.assert_called_once_with(method) + handler.send_test.assert_called_once_with(context=context, method=method) + + def test_dispatch_unsupported(self): + handler = MagicMock() + handler.supports.return_value = False + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): + registry.dispatch(context=context, method=method) + + def test_default(self, mock_db): + registry = DeliveryTestRegistry.default() + assert len(registry._handlers) == 1 + assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) + + +def test_human_input_delivery_test_service(): + registry = MagicMock(spec=DeliveryTestRegistry) + service = HumanInputDeliveryTestService(registry=registry) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + service.send_test(context=context, method=method) + registry.dispatch.assert_called_once_with(context=context, method=method) + + +class TestEmailDeliveryTestHandler: + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + handler = EmailDeliveryTestHandler(session_factory=engine) + assert handler._session_factory.kw["bind"] == engine + + def test_supports(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + assert handler.supports(method) is True + assert handler.supports(MagicMock()) is False + + def test_send_test_unsupported_method(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + with pytest.raises(DeliveryTestUnsupportedError): + handler.send_test(context=MagicMock(), method=MagicMock()) + + def test_send_test_feature_disabled(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), ) - ) - - -def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - ) - method = _make_email_method() - - with pytest.raises(DeliveryTestError, match="Email delivery is not available"): - handler.send_test(context=context, method=method) - - -def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - class DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] - - def is_inited(self) -> bool: - return True - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - mail = DummyMail() - monkeypatch.setattr(service_module, "mail", mail) - monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), - subject="Subject", - body="Value {{#node1.value#}}", + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" ) - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - variable_pool=variable_pool, - ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) - handler.send_test(context=context, method=method) + with pytest.raises(DeliveryTestError, match="Email delivery is not available"): + handler.send_test(context=context, method=method) - assert mail.sent[0]["html"] == "Value OK" + def test_send_test_mail_not_inited(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: False) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="Mail client is not initialized."): + handler.send_test(context=context, method=method) + + def test_send_test_no_recipients(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=[]) + + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="No recipients configured"): + handler.send_test(context=context, method=method) + + def test_send_test_success(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + mock_mail_send = MagicMock() + monkeypatch.setattr(service_module.mail, "send", mock_mail_send) + monkeypatch.setattr(service_module, "render_email_template", lambda t, s: f"RENDERED_{t}") + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=["test@example.com"]) + + variable_pool = VariablePool() + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + variable_pool=variable_pool, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + result = handler.send_test(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + assert result.delivered_to == ["test@example.com"] + mock_mail_send.assert_called_once() + args, kwargs = mock_mail_send.call_args + assert kwargs["to"] == "test@example.com" + assert "RENDERED_Subj" in kwargs["subject"] + + def test_resolve_recipients(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + + # Test Case 1: External Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + subject="", + body="", + ) + ) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] + + # Test Case 2: Member Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + subject="", + body="", + ) + ) + handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + + # Test Case 3: Whole Workspace + method = EmailDeliveryMethod( + config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + ) + handler._query_workspace_member_emails = MagicMock( + return_value={"u1": "u1@example.com", "u2": "u2@example.com"} + ) + recipients = handler._resolve_recipients(tenant_id="t1", method=method) + assert set(recipients) == {"u1@example.com", "u2@example.com"} + + def test_query_workspace_member_emails(self): + mock_session = MagicMock() + mock_session_factory = MagicMock(return_value=mock_session) + mock_session.__enter__.return_value = mock_session + + handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) + + # Empty user_ids + assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} + + # user_ids is None (all) + mock_execute = MagicMock() + mock_session.execute.return_value = mock_execute + mock_execute.all.return_value = [("u1", "u1@example.com")] + + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) + assert result == {"u1": "u1@example.com"} + + # user_ids with values + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) + assert result == {"u1": "u1@example.com"} + + def test_build_substitutions(self): + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + template_vars={"custom": "var"}, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + + assert subs["node_title"] == "title" + assert subs["form_content"] == "content" + assert subs["recipient_email"] == "test@example.com" + assert subs["custom"] == "var" + assert subs["form_token"] == "token123" + assert "form/token123" in subs["form_link"] + + # Without matching recipient + subs_no_match = EmailDeliveryTestHandler._build_substitutions( + context=context, recipient_email="other@example.com" + ) + assert subs_no_match["form_token"] == "" + assert subs_no_match["form_link"] == "" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index a4c6c50593..375e47d7fc 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -16,7 +16,13 @@ from dify_graph.nodes.human_input.entities import ( ) from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from models.human_input import RecipientType -from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError +from services.human_input_service import ( + Form, + FormExpiredError, + FormSubmittedError, + HumanInputService, + InvalidFormDataError, +) @pytest.fixture @@ -285,3 +291,172 @@ def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_fa assert "Missing required inputs" in str(exc_info.value) repo.mark_submitted.assert_not_called() + + +def test_form_properties(sample_form_record): + form = Form(sample_form_record) + assert form.id == "form-id" + assert form.workflow_run_id == "workflow-run-id" + assert form.tenant_id == "tenant-id" + assert form.app_id == "app-id" + assert form.recipient_id == "recipient-id" + assert form.recipient_type == RecipientType.STANDALONE_WEB_APP + assert form.status == HumanInputFormStatus.WAITING + assert form.form_kind == HumanInputFormKind.RUNTIME + assert isinstance(form.created_at, datetime) + assert isinstance(form.expiration_time, datetime) + + +def test_form_submitted_error_init(): + error = FormSubmittedError(form_id="test-form") + assert "form_id=test-form" in error.description + assert error.code == 412 + + +def test_human_input_service_init_with_engine(mocker): + engine = MagicMock(spec=human_input_service_module.Engine) + sessionmaker_mock = mocker.patch("services.human_input_service.sessionmaker") + + HumanInputService(session_factory=engine) + sessionmaker_mock.assert_called_once_with(bind=engine) + + +def test_get_form_by_token_none(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_by_token("invalid") is None + + +def test_get_form_definition_by_token_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + # RecipientType mismatch + assert service.get_form_definition_by_token(RecipientType.CONSOLE, "token") is None + + +def test_get_form_definition_by_token_success(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + form = service.get_form_definition_by_token(RecipientType.STANDALONE_WEB_APP, "token") + assert form is not None + assert form.id == sample_form_record.form_id + + +def test_get_form_definition_by_token_for_console_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record # is STANDALONE_WEB_APP + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_definition_by_token_for_console("token") is None + + +def test_submit_form_by_token_delivery_not_enabled(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + with pytest.raises(human_input_service_module.WebAppDeliveryNotEnabledError): + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "action", {}) + + +def test_submit_form_by_token_no_workflow_run_id(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + # Return record with no workflow_run_id + result_record = dataclasses.replace(sample_form_record, workflow_run_id=None) + repo.mark_submitted.return_value = result_record + + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "submit", {}) + enqueue_spy.assert_not_called() + + +def test_ensure_form_active_errors(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + # Submitted + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + with pytest.raises(human_input_service_module.FormSubmittedError): + service.ensure_form_active(Form(submitted_record)) + + # Timeout status + timeout_record = dataclasses.replace(sample_form_record, status=HumanInputFormStatus.TIMEOUT) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(timeout_record)) + + # Expired time + expired_time_record = dataclasses.replace( + sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + ) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(expired_time_record)) + + +def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + + with pytest.raises(human_input_service_module.FormSubmittedError): + service._ensure_not_submitted(Form(submitted_record)) + + +def test_enqueue_resume_workflow_not_found(mocker, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = None + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + with pytest.raises(AssertionError) as excinfo: + service.enqueue_resume("workflow-run-id") + assert "WorkflowRun not found" in str(excinfo.value) + + +def test_enqueue_resume_app_not_found(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + logger_spy = mocker.patch("services.human_input_service.logger") + + service.enqueue_resume("workflow-run-id") + logger_spy.error.assert_called_once() + + +def test_is_globally_expired_zero_timeout(monkeypatch, sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) + assert service._is_globally_expired(Form(sample_form_record)) is False diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 3c38888753..4b8bdde46b 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -5,8 +5,13 @@ import pytest from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.model import App, AppMode, EndUser, Message -from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError -from services.message_service import MessageService +from services.errors.message import ( + FirstMessageNotExistsError, + LastMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService, attach_message_extra_contents class TestMessageServiceFactory: @@ -244,14 +249,12 @@ class TestMessageServicePaginationByFirstId: mock_query_first = MagicMock() mock_query_history = MagicMock() + query_calls = [] + def query_side_effect(*args): if args[0] == Message: - # First call returns mock for first_message query - if not hasattr(query_side_effect, "call_count"): - query_side_effect.call_count = 0 - query_side_effect.call_count += 1 - - if query_side_effect.call_count == 1: + query_calls.append(args) + if len(query_calls) == 1: return mock_query_first else: return mock_query_history @@ -647,3 +650,410 @@ class TestMessageServicePaginationByLastId: assert len(result.data) == 10 # Last message trimmed assert result.has_more is True assert result.limit == 10 + + +class TestMessageServiceUtilities: + """Unit tests for MessageService module-level utility functions.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 16: attach_message_extra_contents with empty list + def test_attach_message_extra_contents_empty(self): + """Test attach_message_extra_contents with empty list does nothing.""" + # Act & Assert (should not raise error) + attach_message_extra_contents([]) + + # Test 17: attach_message_extra_contents with messages + @patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory): + """Test attach_message_extra_contents correctly attaches content.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # Mock extra content models + mock_content1 = MagicMock() + mock_content1.model_dump.return_value = {"key": "value1"} + mock_content2 = MagicMock() + mock_content2.model_dump.return_value = {"key": "value2"} + + mock_repo.get_by_message_ids.return_value = [[mock_content1], [mock_content2]] + + # Act + attach_message_extra_contents(messages) + + # Assert + mock_repo.get_by_message_ids.assert_called_once_with(["msg-1", "msg-2"]) + messages[0].set_extra_contents.assert_called_once_with([{"key": "value1"}]) + messages[1].set_extra_contents.assert_called_once_with([{"key": "value2"}]) + + # Test 18: attach_message_extra_contents with index out of bounds + @patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory): + """Test attach_message_extra_contents handles missing content lists.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.get_by_message_ids.return_value = [] # Empty returned list + + # Act + attach_message_extra_contents(messages) + + # Assert + messages[0].set_extra_contents.assert_called_once_with([]) + + # Test 19: _create_execution_extra_content_repository + @patch("services.message_service.db") + @patch("services.message_service.sessionmaker") + @patch("services.message_service.SQLAlchemyExecutionExtraContentRepository") + def test_create_execution_extra_content_repository(self, mock_repo_class, mock_sessionmaker, mock_db): + """Test _create_execution_extra_content_repository creates expected repository.""" + from services.message_service import _create_execution_extra_content_repository + + # Act + _create_execution_extra_content_repository() + + # Assert + mock_sessionmaker.assert_called_once() + mock_repo_class.assert_called_once() + + +class TestMessageServiceGetMessage: + """Unit tests for MessageService.get_message method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 20: get_message success for EndUser + @patch("services.message_service.db") + def test_get_message_end_user_success(self, mock_db, factory): + """Test get_message returns message for EndUser.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock(user_id="end-user-123") + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + mock_query.where.assert_called_once() + + # Test 21: get_message success for Account (Admin) + @patch("services.message_service.db") + def test_get_message_account_success(self, mock_db, factory): + """Test get_message returns message for Account.""" + # Arrange + from models import Account + + app = factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + + # Test 22: get_message not found + @patch("services.message_service.db") + def test_get_message_not_found(self, mock_db, factory): + """Test get_message raises MessageNotExistsError when not found.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + +class TestMessageServiceFeedback: + """Unit tests for MessageService feedback-related methods.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 23: create_feedback - new feedback for EndUser + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory): + """Test creating new feedback for an end user.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + message.user_feedback = None + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating="like", + content="Good answer", + ) + + # Assert + assert result.rating == "like" + assert result.content == "Good answer" + assert result.from_source == "user" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + # Test 24: create_feedback - update feedback for Account + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_update_account(self, mock_get_message, mock_db, factory): + """Test updating existing feedback for an account.""" + # Arrange + from models import Account, MessageFeedback + + app = factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + feedback = MagicMock(spec=MessageFeedback) + message.admin_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating="dislike", + content="Bad answer", + ) + + # Assert + assert result == feedback + assert feedback.rating == "dislike" + assert feedback.content == "Bad answer" + mock_db.session.commit.assert_called_once() + + # Test 25: create_feedback - delete feedback (rating is None) + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_delete(self, mock_get_message, mock_db, factory): + """Test deleting feedback by passing rating=None.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + feedback = MagicMock() + message.user_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating=None, + content=None, + ) + + # Assert + assert result == feedback + mock_db.session.delete.assert_called_once_with(feedback) + mock_db.session.commit.assert_called_once() + + # Test 26: get_all_messages_feedbacks + @patch("services.message_service.db") + def test_get_all_messages_feedbacks(self, mock_db, factory): + """Test get_all_messages_feedbacks returns list of dicts.""" + # Arrange + app = factory.create_app_mock() + feedback = MagicMock() + feedback.to_dict.return_value = {"id": "fb-1"} + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.limit.return_value = mock_query + mock_query.offset.return_value = mock_query + mock_query.all.return_value = [feedback] + + # Act + result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10) + + # Assert + assert result == [{"id": "fb-1"}] + mock_query.limit.assert_called_with(10) + mock_query.offset.assert_called_with(0) + + +class TestMessageServiceSuggestedQuestions: + """Unit tests for MessageService.get_suggested_questions_after_answer method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 27: get_suggested_questions_after_answer - user is None + def test_get_suggested_questions_user_none(self, factory): + app = factory.create_app_mock() + with pytest.raises(ValueError, match="user cannot be None"): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=None, message_id="msg-123", invoke_from=MagicMock() + ) + + # Test 28: get_suggested_questions_after_answer - Advanced Chat success + @patch("services.message_service.ModelManager") + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_advanced_chat_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_config_manager, + mock_workflow_service, + mock_model_manager, + factory, + ): + """Test successful suggested questions generation in Advanced Chat mode.""" + from core.app.entities.app_invoke_entities import InvokeFrom + + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + workflow = MagicMock() + mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = True + mock_config_manager.get_app_config.return_value = app_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=InvokeFrom.WEB_APP + ) + + # Assert + assert result == ["Q1?"] + mock_workflow_service.return_value.get_published_workflow.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 29: get_suggested_questions_after_answer - Chat app success (no override) + @patch("services.message_service.db") + @patch("services.message_service.ModelManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_chat_app_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_model_manager, + mock_db, + factory, + ): + """Test successful suggested questions generation in basic Chat mode.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + conversation = MagicMock() + conversation.override_model_configs = None + mock_conversation_service.get_conversation.return_value = conversation + + app_model_config = MagicMock() + app_model_config.suggested_questions_after_answer_dict = {"enabled": True} + app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"} + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = app_model_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + ) + + # Assert + assert result == ["Q1?"] + mock_query.first.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 30: get_suggested_questions_after_answer - Disabled Error + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_disabled_error( + self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory + ): + """Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + mock_get_message.return_value = factory.create_message_mock() + + workflow = MagicMock() + mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = False + mock_config_manager.get_app_config.return_value = app_config + + # Act & Assert + with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + )